diff --git a/tools/bazel.rc b/.bazelrc
similarity index 95%
rename from tools/bazel.rc
rename to .bazelrc
index 1fdf51f53e29c7111cf89c016400b710051cf9c6..cd7e13ddfc146208f79be900917b05b694869d72 100644
--- a/tools/bazel.rc
+++ b/.bazelrc
@@ -76,7 +76,6 @@ build:nonccl --define=no_nccl_support=true
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
-build --define=grpc_no_ares=true
build --spawn_strategy=standalone
build --genrule_strategy=standalone
@@ -93,3 +92,11 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
build --define=PREFIX=/usr
build --define=LIBDIR=$(PREFIX)/lib
build --define=INCLUDEDIR=$(PREFIX)/include
+
+# Default options should come above this line
+
+# Options from ./configure
+try-import %workspace%/.tf_configure.bazelrc
+
+# Put user-specific options in .bazelrc.user
+try-import %workspace%/.bazelrc.user
diff --git a/.gitignore b/.gitignore
index 90324058600bee46af56e49028977971848a80de..e1d352c238a1b2d4febe0f5d4a30cfa0c942f7e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,7 @@
.DS_Store
.ipynb_checkpoints
node_modules
-/.bazelrc
+/.bazelrc.user
/.tf_configure.bazelrc
/bazel-*
/bazel_pip
diff --git a/CODEOWNERS b/CODEOWNERS
index 54a61a4d72c40d297d90d53e223f64f813d9167d..cb3fa2312405ce44d5dfc30ea4164740f436e07e 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,7 +1,7 @@
# Where component owners are known, add them here.
/tenosrflow/core/debug @caisq
-/tensorflow/core/nccl/ @azaks @csigg
+/tensorflow/core/nccl/ @azaks2 @chsigg
/tensorflow/core/platform/windows/ @mrry
/tensorflow/core/platform/s3 @yongtang
/tensorflow/go @asimshankar
@@ -51,13 +51,13 @@
/tensorflow/contrib/pi_examples/ @maciekcc
/tensorflow/contrib/quantization/ @petewarden
/tensorflow/contrib/rnn/ @ebrevdo @scottzhu
-/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl
+/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie
/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang
/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh
/tensorflow/contrib/slim/ @sguada @thenbasilmanran
/tensorflow/contrib/stateless/ @girving @alextp
/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank
-/tensorflow/contrib/tensorrt/ @aaroey
+/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2
# NEED OWNER: /tensorflow/contrib/testing/
/tensorflow/contrib/timeseries/ @allenlavoie
/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj
diff --git a/README.md b/README.md
index 044174947a094d43a51f7140dd40ec0f17801d40..519815d006cc33be10132909baf414a4bd843435 100644
--- a/README.md
+++ b/README.md
@@ -113,11 +113,12 @@ The TensorFlow project strives to abide by generally accepted best practices in
Build Type | Status | Artifacts
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**IBM s390x** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
-**IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA
-**IBM ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
-**IBM ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
+**Linux ppc64le CPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
+**Linux ppc64le CPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/)
+**Linux ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
+**Linux ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
-**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)
[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)
[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)
[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl)
+**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.12.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp27-cp27mu-linux_x86_64.whl)
[1.12.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp34-cp34m-linux_x86_64.whl)
[1.12.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp35-cp35m-linux_x86_64.whl)
[1.12.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp36-cp36m-linux_x86_64.whl)
## For more information
diff --git a/RELEASE.md b/RELEASE.md
index b13b071bd6cf4d3a260c8e248a67d23e1a688498..32abdcea497618918964174a661a6ba872598f65 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -7,6 +7,8 @@
Serving.
* Keras models now support evaluating with a `tf.data.Dataset`.
* TensorFlow binaries are built with XLA support linked in by default.
+* Ignite Dataset added to contrib/ignite, which allows working with Apache
+ Ignite.
## Bug Fixes and Other Changes
diff --git a/WORKSPACE b/WORKSPACE
index 0c7bc085b512b084b9470abe17326d7c119aa327..7057d3f149e766cd2983ecc89509f84c37075602 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,5 +1,7 @@
workspace(name = "org_tensorflow")
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
http_archive(
name = "io_bazel_rules_closure",
sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae",
@@ -14,30 +16,27 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
closure_repositories()
-http_archive(
- name = "base_images_docker",
- sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9",
- strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6",
- urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"],
-)
+load("//third_party/toolchains/preconfig/generate:archives.bzl",
+ "bazel_toolchains_archive")
-http_archive(
- name = "bazel_toolchains",
- sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb",
- strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b",
- urls = [
- "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz",
- ],
+bazel_toolchains_archive()
+
+load(
+ "@bazel_toolchains//repositories:repositories.bzl",
+ bazel_toolchains_repositories = "repositories",
)
-http_archive(
- name = "io_bazel_rules_docker",
- sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd",
- strip_prefix = "rules_docker-0.5.1",
- urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"],
+bazel_toolchains_repositories()
+
+load(
+ "@io_bazel_rules_docker//container:container.bzl",
+ container_repositories = "repositories",
)
-load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace")
+container_repositories()
+
+load("//third_party/toolchains/preconfig/generate:workspace.bzl",
+ "remote_config_workspace")
remote_config_workspace()
@@ -45,7 +44,7 @@ remote_config_workspace()
# files, in case the parsing of those build files depends on the bazel
# version we require here.
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
-check_bazel_version_at_least("0.15.0")
+check_bazel_version_at_least("0.18.0")
load("//tensorflow:workspace.bzl", "tf_workspace")
@@ -57,9 +56,9 @@ android_workspace()
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_workspace()
-new_http_archive(
+http_archive(
name = "inception_v1",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip",
@@ -67,9 +66,9 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "mobile_ssd",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
@@ -77,9 +76,9 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "mobile_multibox",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",
@@ -87,9 +86,9 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "stylize",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
@@ -97,9 +96,9 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "speech_commands",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
diff --git a/configure.py b/configure.py
index f087da002d534e1f0f4c1598e87217168c892dbe..1e732db26404906901a9eeab97a5e75137ee8388 100644
--- a/configure.py
+++ b/configure.py
@@ -255,18 +255,6 @@ def setup_python(environ_cp):
def reset_tf_configure_bazelrc():
"""Reset file that contains customized config settings."""
open(_TF_BAZELRC, 'w').close()
- bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc')
-
- data = []
- if os.path.exists(bazelrc_path):
- with open(bazelrc_path, 'r') as f:
- data = f.read().splitlines()
- with open(bazelrc_path, 'w') as f:
- for l in data:
- if _TF_BAZELRC_FILENAME in l:
- continue
- f.write('%s\n' % l)
- f.write('import %%workspace%%/%s\n' % _TF_BAZELRC_FILENAME)
def cleanup_makefile():
"""Delete any leftover BUILD files from the Makefile build.
@@ -452,11 +440,12 @@ def convert_version_to_int(version):
return int(version_str)
-def check_bazel_version(min_version):
- """Check installed bazel version is at least min_version.
+def check_bazel_version(min_version, max_version):
+ """Check installed bazel version is between min_version and max_version.
Args:
min_version: string for minimum bazel version.
+ max_version: string for maximum bazel version.
Returns:
The bazel version detected.
@@ -474,6 +463,7 @@ def check_bazel_version(min_version):
min_version_int = convert_version_to_int(min_version)
curr_version_int = convert_version_to_int(curr_version)
+ max_version_int = convert_version_to_int(max_version)
# Check if current bazel version can be detected properly.
if not curr_version_int:
@@ -486,7 +476,12 @@ def check_bazel_version(min_version):
if curr_version_int < min_version_int:
print('Please upgrade your bazel installation to version %s or higher to '
'build TensorFlow!' % min_version)
- sys.exit(0)
+ sys.exit(1)
+ if (curr_version_int > max_version_int and
+ 'TF_IGNORE_MAX_BAZEL_VERSION' not in os.environ):
+ print('Please downgrade your bazel installation to version %s or lower to '
+ 'build TensorFlow!' % max_version)
+ sys.exit(1)
return curr_version
@@ -1559,11 +1554,9 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
- check_bazel_version('0.15.0')
+ check_bazel_version('0.19.0', '0.20.0')
reset_tf_configure_bazelrc()
- # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later
- write_to_bazelrc('import %workspace%/tools/bazel.rc')
cleanup_makefile()
setup_python(environ_cp)
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index fd4b94202aad24a82abef8abd16431f61a8326f0..449a1372edb031c68786d8672e2a1499c2b3d047 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -267,6 +267,15 @@ config_setting(
visibility = ["//visibility:public"],
)
+# By default, XLA GPU is compiled into tensorflow when building with
+# --config=cuda even when `with_xla_support` is false. The config setting
+# here allows us to override the behavior if needed.
+config_setting(
+ name = "no_xla_deps_in_cuda",
+ define_values = {"no_xla_deps_in_cuda": "true"},
+ visibility = ["//visibility:public"],
+)
+
config_setting(
name = "with_gdr_support",
define_values = {"with_gdr_support": "true"},
@@ -606,9 +615,11 @@ py_library(
name = "tensorflow_py",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
- deps = [
+ deps = select({
+ "api_version_2": [],
+ "//conditions:default": ["//tensorflow/contrib:contrib_py"],
+ }) + [
":tensorflow_py_no_contrib",
- "//tensorflow/contrib:contrib_py",
"//tensorflow/python/estimator:estimator_py",
],
)
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index f13623b0d57d3b59bb9455a46a9fab29fee25784..4eba763129a6aef40e3c130d56bf8ab19638b7ca 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -20,14 +20,14 @@ from __future__ import print_function as _print_function
import os as _os
+# API IMPORTS PLACEHOLDER
+
# pylint: disable=g-bad-import-order
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow_estimator.python.estimator.api.estimator'))
-# API IMPORTS PLACEHOLDER
-
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that.
@@ -35,8 +35,9 @@ _tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: di
if _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
-# Calls to enable and disable features.
-enable_eager_execution() # pylint: disable=undefined-variable
+# Enable TF2 behaviors
+from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top
+_compat.enable_v2_behavior()
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py
index 65bdb6cb1b5e6fb0656a12b932d767aeacfccd29..21b5277614667bdbd7271ac3e57f5b69d5a19264 100644
--- a/tensorflow/api_template_v1.__init__.py
+++ b/tensorflow/api_template_v1.__init__.py
@@ -23,13 +23,13 @@ import os as _os
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
+# API IMPORTS PLACEHOLDER
+
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow_estimator.python.estimator.api.estimator'))
-# API IMPORTS PLACEHOLDER
-
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index f653e581bf3beda9fdbf8fb7905a4f9fe170e7fb..25df970ecab0757f23465ab19e7f45de0c759458 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -175,6 +175,34 @@ tf_cuda_library(
],
)
+tf_cuda_library(
+ name = "env",
+ srcs = [
+ "env.cc",
+ ],
+ hdrs = [
+ "env.h",
+ ],
+ copts = tf_copts(),
+ visibility = ["//visibility:public"],
+ deps = select({
+ "//tensorflow:android": [
+ ":c_api",
+ ":tf_status_helper",
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ "//tensorflow/core:platform_env",
+ "//tensorflow/core:lib",
+ ],
+ "//conditions:default": [
+ ":c_api",
+ ":tf_status_helper",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:platform_env",
+ "//tensorflow/core:lib",
+ ],
+ }) + [":c_api_internal"],
+)
+
tf_cuda_library(
name = "kernels",
srcs = [
@@ -188,10 +216,14 @@ tf_cuda_library(
deps = select({
"//tensorflow:android": [
":c_api",
+ ":c_api_internal",
+ ":tf_status_helper",
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api",
+ ":c_api_internal",
+ ":tf_status_helper",
"//tensorflow/core:framework",
],
}),
@@ -330,6 +362,27 @@ tf_kernel_library(
alwayslink = 1,
)
+tf_cuda_cc_test(
+ name = "env_test",
+ size = "small",
+ srcs = ["env_test.cc"],
+ linkopts = select({
+ "//tensorflow:darwin": ["-headerpad_max_install_names"],
+ "//conditions:default": [],
+ }),
+ tags = ["noasan"],
+ # We must ensure that the dependencies can be dynamically linked since
+ # the shared library must be able to use core:framework.
+ # linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":c_api",
+ ":env",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
tf_cuda_cc_test(
name = "kernels_test",
size = "small",
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index f13e8777dff164bcd8eedf46310ae846abd0c804..9580215a317b1a6b1cdacbd430a1764af61be990 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -136,16 +136,22 @@ const char* TF_Message(const TF_Status* s) {
namespace {
class TF_ManagedBuffer : public TensorBuffer {
public:
- void* data_;
- size_t len_;
- void (*deallocator_)(void* data, size_t len, void* arg);
- void* deallocator_arg_;
+ TF_ManagedBuffer(void* data, size_t len,
+ void (*deallocator)(void* data, size_t len, void* arg),
+ void* deallocator_arg)
+ : TensorBuffer(data),
+ len_(len),
+ deallocator_(deallocator),
+ deallocator_arg_(deallocator_arg) {}
+
+ const size_t len_;
+ void (*const deallocator_)(void* data, size_t len, void* arg);
+ void* const deallocator_arg_;
~TF_ManagedBuffer() override {
- (*deallocator_)(data_, len_, deallocator_arg_);
+ (*deallocator_)(data(), len_, deallocator_arg_);
}
- void* data() const override { return data_; }
size_t size() const override { return len_; }
TensorBuffer* root_buffer() override { return this; }
void FillAllocationDescription(AllocationDescription* proto) const override {
@@ -199,8 +205,7 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
dimvec[i] = static_cast(dims[i]);
}
- TF_ManagedBuffer* buf = new TF_ManagedBuffer;
- buf->len_ = len;
+ TF_ManagedBuffer* buf = nullptr;
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
tensorflow::DataTypeCanUseMemcpy(static_cast(dtype)) &&
reinterpret_cast(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) !=
@@ -212,17 +217,15 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
//
// Other types have the same representation, so copy only if it is safe to
// do so.
- buf->data_ = allocate_tensor("TF_NewTensor", len);
- std::memcpy(buf->data_, data, len);
- buf->deallocator_ = deallocate_buffer;
- buf->deallocator_arg_ = nullptr;
+ buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len,
+ deallocate_buffer, nullptr);
+ std::memcpy(buf->data(), data, len);
// Free the original buffer.
deallocator(data, len, deallocator_arg);
} else {
- buf->data_ = data;
- buf->deallocator_ = deallocator;
- buf->deallocator_arg_ = deallocator_arg;
+ buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
}
+
TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf};
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) {
@@ -477,14 +480,15 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
CHECK_EQ(nelems, 0);
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
- return TF_NewTensor(dtype, reinterpret_cast(dims.data()),
- shape.dims(), reinterpret_cast(&empty), 0,
- [](void*, size_t, void*) {}, nullptr);
+ return TF_NewTensor(
+ dtype, reinterpret_cast(dims.data()), shape.dims(),
+ reinterpret_cast(&empty), 0, [](void*, size_t, void*) {}, nullptr);
}
// Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
if (!src.IsInitialized()) {
status->status = FailedPrecondition(
"attempt to use a tensor with an uninitialized value");
@@ -1592,18 +1596,20 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
break; \
}
- LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0;
- for (int i = 0; i < attr->list().s_size();
- ++i) { metadata.total_size += attr->list().s(i).size(); });
+ LIST_CASE(
+ s, TF_ATTR_STRING, metadata.total_size = 0;
+ for (int i = 0; i < attr->list().s_size();
+ ++i) { metadata.total_size += attr->list().s(i).size(); });
LIST_CASE(i, TF_ATTR_INT);
LIST_CASE(f, TF_ATTR_FLOAT);
LIST_CASE(b, TF_ATTR_BOOL);
LIST_CASE(type, TF_ATTR_TYPE);
- LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0;
- for (int i = 0; i < attr->list().shape_size(); ++i) {
- const auto& s = attr->list().shape(i);
- metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
- });
+ LIST_CASE(
+ shape, TF_ATTR_SHAPE, metadata.total_size = 0;
+ for (int i = 0; i < attr->list().shape_size(); ++i) {
+ const auto& s = attr->list().shape(i);
+ metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
+ });
LIST_CASE(tensor, TF_ATTR_TENSOR);
LIST_CASE(tensor, TF_ATTR_FUNC);
#undef LIST_CASE
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 3d56268110edbe96616201d15a69cc8c84d3115a..c7abba85521fccec07983cd5ab4f94a8368d6181 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -91,7 +91,7 @@ extern "C" {
// --------------------------------------------------------------------------
// TF_Version returns a string describing version information of the
// TensorFlow library. TensorFlow using semantic versioning.
-TF_CAPI_EXPORT extern const char* TF_Version();
+TF_CAPI_EXPORT extern const char* TF_Version(void);
// --------------------------------------------------------------------------
// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
@@ -157,7 +157,7 @@ typedef enum TF_Code {
typedef struct TF_Status TF_Status;
// Return a new status object.
-TF_CAPI_EXPORT extern TF_Status* TF_NewStatus();
+TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void);
// Delete a previously created status object.
TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*);
@@ -196,7 +196,7 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto,
size_t proto_len);
// Useful for passing *out* a protobuf.
-TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer();
+TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void);
TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*);
@@ -305,7 +305,7 @@ TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len);
typedef struct TF_SessionOptions TF_SessionOptions;
// Return a new options object.
-TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions();
+TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void);
// Set the target in TF_SessionOptions.options.
// target can be empty, a single entry, or a comma separated list of entries.
@@ -338,7 +338,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*);
typedef struct TF_Graph TF_Graph;
// Return a new graph object.
-TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph();
+TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void);
// Destroy an options object. Graph will be deleted once no more
// TFSession's are referencing it.
@@ -890,7 +890,8 @@ TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph,
// TF_GraphImportGraphDef.
typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions;
-TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions();
+TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(
+ void);
TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions(
TF_ImportGraphDefOptions* opts);
@@ -1611,7 +1612,7 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);
//
// The data in the buffer will be the serialized OpList proto for ops registered
// in this address space.
-TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList();
+TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void);
// TF_ApiDefMap encapsulates a collection of API definitions for an operation.
//
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 3693cc85996365360253c8a94c29272a16e11e9a..81343f7bc027be82d28164be51011c794715d03a 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -66,7 +66,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
}
TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
- unsigned char gpu_memory_allow_growth) {
+ unsigned char gpu_memory_allow_growth,
+ unsigned int num_cpu_devices) {
tensorflow::ConfigProto config;
auto* optimizer_options =
config.mutable_graph_options()->mutable_optimizer_options();
@@ -87,6 +88,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
auto* gpu_options = config.mutable_gpu_options();
gpu_options->set_allow_growth(gpu_memory_allow_growth);
+ (*config.mutable_device_count())["CPU"] = num_cpu_devices;
+
// TODO(b/113217601): This is needed for EagerContext::runner_ to use a
// threadpool, so that we avoid the possibility of running the runner_ in the
// threadpool of GPU event mgr, as that can trigger more callbacks to be
@@ -6530,7 +6533,7 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset/cycle_length"
+ name: "ExperimentalParallelInterleaveDataset/cycle_length"
op: "Const"
attr {
key: "dtype"
@@ -6551,7 +6554,7 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset/block_length"
+ name: "ExperimentalParallelInterleaveDataset/block_length"
op: "Const"
attr {
key: "dtype"
@@ -6572,7 +6575,7 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset/sloppy"
+ name: "ExperimentalParallelInterleaveDataset/sloppy"
op: "Const"
attr {
key: "dtype"
@@ -6593,7 +6596,7 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset/buffer_output_elements"
+ name: "ExperimentalParallelInterleaveDataset/buffer_output_elements"
op: "Const"
attr {
key: "dtype"
@@ -6614,7 +6617,7 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset/prefetch_input_elements"
+ name: "ExperimentalParallelInterleaveDataset/prefetch_input_elements"
op: "Const"
attr {
key: "dtype"
@@ -6635,14 +6638,14 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset"
- op: "ParallelInterleaveDataset"
+ name: "ExperimentalParallelInterleaveDataset"
+ op: "ExperimentalParallelInterleaveDataset"
input: "RepeatDataset:handle:0"
- input: "ParallelInterleaveDataset/cycle_length:output:0"
- input: "ParallelInterleaveDataset/block_length:output:0"
- input: "ParallelInterleaveDataset/sloppy:output:0"
- input: "ParallelInterleaveDataset/buffer_output_elements:output:0"
- input: "ParallelInterleaveDataset/prefetch_input_elements:output:0"
+ input: "ExperimentalParallelInterleaveDataset/cycle_length:output:0"
+ input: "ExperimentalParallelInterleaveDataset/block_length:output:0"
+ input: "ExperimentalParallelInterleaveDataset/sloppy:output:0"
+ input: "ExperimentalParallelInterleaveDataset/buffer_output_elements:output:0"
+ input: "ExperimentalParallelInterleaveDataset/prefetch_input_elements:output:0"
attr {
key: "Targuments"
value {
@@ -6742,7 +6745,7 @@ library {
node_def {
name: "ShuffleDataset_2"
op: "ShuffleDataset"
- input: "ParallelInterleaveDataset:handle:0"
+ input: "ExperimentalParallelInterleaveDataset:handle:0"
input: "ShuffleDataset_2/buffer_size_1:output:0"
input: "ShuffleDataset_2/seed_2:output:0"
input: "ShuffleDataset_2/seed2_2:output:0"
@@ -8535,8 +8538,9 @@ TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
// Reduce GPU memory allocation, and set appropriate config options for TFE
// context.
- auto* config =
- TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true);
+ auto* config = TF_CreateConfig(
+ /*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
+ 10);
TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
if (!status->status.ok()) {
CHECK(!config);
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 80c8bfe594c4c89606efd01bec7f50e7a86b5bda..cb7a146846ff0bdac09f4a90765f78e0ada75718 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -67,9 +67,10 @@ TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if
// `enable_xla_compilation` is non-zero, and OFF otherwise.
// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`.
+// c) ConfigProto.device_count["CPU"] is set to `num_cpu_devices`.
TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig(
- unsigned char enable_xla_compilation,
- unsigned char gpu_memory_allow_growth);
+ unsigned char enable_xla_compilation, unsigned char gpu_memory_allow_growth,
+ unsigned int num_cpu_devices);
// Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level
// is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE
@@ -239,7 +240,7 @@ TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv);
// Platform-specific implementation to return an unused port. (This should used
// in tests only.)
-TF_CAPI_EXPORT int TF_PickUnusedPortOrDie();
+TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void);
// Fast path method that makes constructing a single scalar tensor require less
// overhead and copies.
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 8d6c8d958d5961fce817156a14eb2b2940c1f2f0..120748ab763a3358b6e38e64bb3b6fd2ea32f7c3 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -48,7 +48,7 @@ extern "C" {
typedef struct TFE_ContextOptions TFE_ContextOptions;
// Return a new options object.
-TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions();
+TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(void);
// Set the config in TF_ContextOptions.options.
// config should be a serialized tensorflow.ConfigProto proto.
@@ -170,23 +170,11 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
int dim_index,
TF_Status* status);
-// Returns the device of the operation that produced `h`.
-// If `h` was produced by a copy, returns the destination device of
-// the copy. Note that returned device name is not always the device
-// holding the tensor handle's memory. If you want the latter, use
-// TFE_TensorHandleBackingDeviceName.
-// This function will block till the operation that produces `h` has completed.
-//
-// Device on which the kernel of the operation that produced `h` ran.
-//
-// If `h` was produced by a copy, returns the destination device of
-// the copy.
-//
-// Note that returned device name is not always the device that owns the memory
-// that backs the tensor handle. For the latter see
-// TFE_TensorHandleBackingDeviceName.
-//
-// This function will block till the operation that produces `h` has completed.
+// Returns the device of the operation that produced `h`. If `h` was produced by
+// a copy, returns the destination device of the copy. Note that the returned
+// device name is not always the device holding the tensor handle's memory. If
+// you want the latter, use TFE_TensorHandleBackingDeviceName. This function
+// will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
TFE_TensorHandle* h, TF_Status* status);
diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1c35ff9001d0ee1ab0fbae9e1bcc07116fab1065
--- /dev/null
+++ b/tensorflow/c/env.cc
@@ -0,0 +1,183 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/env.h"
+
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
+
+// Cursor over a heap-allocated vector of strings. This is the concrete type
+// behind the TF_StringStream handle consumed by TF_StringStreamNext and
+// released by TF_StringStreamDone.
+struct TF_StringStream {
+ // The strings being iterated; TF_StringStreamDone deletes this vector.
+ std::vector<::tensorflow::string>* list;
+ // Index of the next element TF_StringStreamNext will hand out.
+ size_t position;
+};
+
+// Creates the directory `dirname` through the default Env and reports the
+// outcome in `status`.
+void TF_CreateDir(const char* dirname, TF_Status* status) {
+  // Start from OK, then record whatever Env::CreateDir returned.
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::Env* const env = ::tensorflow::Env::Default();
+  const ::tensorflow::Status result = env->CreateDir(dirname);
+  ::tensorflow::Set_TF_Status_from_Status(status, result);
+}
+
+// Deletes the directory `dirname` through the default Env and reports the
+// outcome in `status`.
+void TF_DeleteDir(const char* dirname, TF_Status* status) {
+  // Start from OK, then record whatever Env::DeleteDir returned.
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::Env* const env = ::tensorflow::Env::Default();
+  const ::tensorflow::Status result = env->DeleteDir(dirname);
+  ::tensorflow::Set_TF_Status_from_Status(status, result);
+}
+
+// Recursively deletes `dirname` through the default Env. The outcome is
+// reported in `status`; the number of files and directories that could not be
+// deleted is written to *undeleted_file_count and *undeleted_dir_count.
+void TF_DeleteRecursively(const char* dirname, int64_t* undeleted_file_count,
+                          int64_t* undeleted_dir_count, TF_Status* status) {
+  // Zero-initialize so the out-params never receive indeterminate values if
+  // the callee returns an error before writing its counters.
+  ::tensorflow::int64 undeleted_files = 0;
+  ::tensorflow::int64 undeleted_dirs = 0;
+
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::Set_TF_Status_from_Status(
+      status, ::tensorflow::Env::Default()->DeleteRecursively(
+                  dirname, &undeleted_files, &undeleted_dirs));
+  *undeleted_file_count = undeleted_files;
+  *undeleted_dir_count = undeleted_dirs;
+}
+
+// Queries the default Env for statistics about `filename`. On success the
+// fields of *stats are filled in; on failure *stats is left untouched and the
+// error is reported in `status`.
+void TF_FileStat(const char* filename, TF_FileStatistics* stats,
+                 TF_Status* status) {
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::FileStatistics native_stats;
+  const ::tensorflow::Status stat_result =
+      ::tensorflow::Env::Default()->Stat(filename, &native_stats);
+  ::tensorflow::Set_TF_Status_from_Status(status, stat_result);
+  if (!stat_result.ok()) {
+    return;
+  }
+  stats->length = native_stats.length;
+  stats->mtime_nsec = native_stats.mtime_nsec;
+  stats->is_directory = native_stats.is_directory;
+}
+
+// Opens `filename` for writing through the default Env. On success *handle
+// receives an opaque pointer to the underlying WritableFile; on failure
+// *handle is not modified and the error is reported in `status`.
+void TF_NewWritableFile(const char* filename, TF_WritableFileHandle** handle,
+                        TF_Status* status) {
+  std::unique_ptr<::tensorflow::WritableFile> f;
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::Status s =
+      ::tensorflow::Env::Default()->NewWritableFile(filename, &f);
+  ::tensorflow::Set_TF_Status_from_Status(status, s);
+
+  if (s.ok()) {
+    // Ownership moves to the caller; TF_CloseWritableFile deletes the object.
+    *handle = reinterpret_cast<TF_WritableFileHandle*>(f.release());
+  }
+}
+
+// Closes the file behind `handle`, reports the result of Close() in `status`,
+// and deletes the underlying object whether or not the close succeeded.
+void TF_CloseWritableFile(TF_WritableFileHandle* handle, TF_Status* status) {
+  // Taking ownership here makes the delete happen at scope exit, after Close().
+  std::unique_ptr<::tensorflow::WritableFile> file(
+      reinterpret_cast<::tensorflow::WritableFile*>(handle));
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::Set_TF_Status_from_Status(status, file->Close());
+}
+
+// Calls Sync() on the WritableFile behind `handle` and reports the result in
+// `status`.
+void TF_SyncWritableFile(TF_WritableFileHandle* handle, TF_Status* status) {
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::WritableFile* const file =
+      reinterpret_cast<::tensorflow::WritableFile*>(handle);
+  const ::tensorflow::Status sync_result = file->Sync();
+  ::tensorflow::Set_TF_Status_from_Status(status, sync_result);
+}
+
+// Calls Flush() on the WritableFile behind `handle` and reports the result in
+// `status`.
+void TF_FlushWritableFile(TF_WritableFileHandle* handle, TF_Status* status) {
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::WritableFile* const file =
+      reinterpret_cast<::tensorflow::WritableFile*>(handle);
+  const ::tensorflow::Status flush_result = file->Flush();
+  ::tensorflow::Set_TF_Status_from_Status(status, flush_result);
+}
+
+// Appends `length` bytes starting at `data` to the WritableFile behind
+// `handle` and reports the result in `status`.
+void TF_AppendWritableFile(TF_WritableFileHandle* handle, const char* data,
+                           size_t length, TF_Status* status) {
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::WritableFile* const file =
+      reinterpret_cast<::tensorflow::WritableFile*>(handle);
+  const ::tensorflow::StringPiece chunk(data, length);
+  ::tensorflow::Set_TF_Status_from_Status(status, file->Append(chunk));
+}
+
+// Deletes `filename` through the default Env and reports the outcome in
+// `status`.
+void TF_DeleteFile(const char* filename, TF_Status* status) {
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::Env* const env = ::tensorflow::Env::Default();
+  const ::tensorflow::Status result = env->DeleteFile(filename);
+  ::tensorflow::Set_TF_Status_from_Status(status, result);
+}
+
+// Advances the stream by one element. When an element is available, *result
+// points at its NUL-terminated contents and true is returned; once the stream
+// is exhausted, *result is set to nullptr and false is returned.
+bool TF_StringStreamNext(TF_StringStream* list, const char** result) {
+  std::vector<::tensorflow::string>& items = *list->list;
+  const bool has_next = list->position < items.size();
+  if (has_next) {
+    *result = items.at(list->position).c_str();
+    ++list->position;
+  } else {
+    *result = nullptr;
+  }
+  return has_next;
+}
+
+// Releases the stream: first the vector of strings it iterates over, then the
+// stream object itself.
+void TF_StringStreamDone(TF_StringStream* list) {
+  std::vector<::tensorflow::string>* const strings = list->list;
+  delete strings;
+  delete list;
+}
+// Lists the children of `dirname` through the default Env. A newly allocated
+// stream positioned at its first element is always returned, including when
+// `status` reports an error; the caller releases it with TF_StringStreamDone.
+TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) {
+  auto* stream = new TF_StringStream;
+  stream->list = new std::vector<::tensorflow::string>;
+  stream->position = 0;
+
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::Set_TF_Status_from_Status(
+      status,
+      ::tensorflow::Env::Default()->GetChildren(dirname, stream->list));
+  return stream;
+}
+
+// Collects the local temp directories reported by the default Env into a
+// newly allocated stream positioned at its first element. The caller releases
+// the stream with TF_StringStreamDone.
+TF_StringStream* TF_GetLocalTempDirectories() {
+  auto* stream = new TF_StringStream;
+  stream->list = new std::vector<::tensorflow::string>;
+  stream->position = 0;
+
+  ::tensorflow::Env::Default()->GetLocalTempDirectories(stream->list);
+  return stream;
+}
+
+// Returns the number of nanoseconds since the Unix epoch.
+TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) {
+  ::tensorflow::Env* const env = ::tensorflow::Env::Default();
+  return env->NowNanos();
+}
+
+// Returns the default Env's current time, in microseconds since the Unix
+// epoch.
+TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) {
+  ::tensorflow::Env* const env = ::tensorflow::Env::Default();
+  return env->NowMicros();
+}
+
+// Returns the default Env's current time, in seconds since the Unix epoch.
+TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) {
+  ::tensorflow::Env* const env = ::tensorflow::Env::Default();
+  return env->NowSeconds();
+}
+
+// Fills *options with the defaults: zero stack and guard sizes, and -1 for
+// the NUMA node.
+void TF_DefaultThreadOptions(TF_ThreadOptions* options) {
+  TF_ThreadOptions defaults;
+  defaults.stack_size = 0;
+  defaults.guard_size = 0;
+  defaults.numa_node = -1;
+  *options = defaults;
+}
+
+// Starts a thread named `thread_name` through the default Env. The thread
+// runs work_func(param), and `options` is translated field by field into a
+// ::tensorflow::ThreadOptions. The returned handle is an opaque pointer to
+// the ::tensorflow::Thread and is released with TF_JoinThread.
+TF_Thread* TF_StartThread(const TF_ThreadOptions* options,
+                          const char* thread_name, void (*work_func)(void*),
+                          void* param) {
+  ::tensorflow::ThreadOptions cc_options;
+  cc_options.stack_size = options->stack_size;
+  cc_options.guard_size = options->guard_size;
+  cc_options.numa_node = options->numa_node;
+  // The lambda captures work_func and param by value so the new thread can
+  // invoke them after this function has returned.
+  return reinterpret_cast<TF_Thread*>(::tensorflow::Env::Default()->StartThread(
+      cc_options, thread_name, [=]() { (*work_func)(param); }));
+}
+
+// Waits for `thread` to finish and releases it.
+void TF_JoinThread(TF_Thread* thread) {
+  auto* cc_thread = reinterpret_cast<::tensorflow::Thread*>(thread);
+  // ::tensorflow::Thread joins on destruction
+  delete cc_thread;
+}
diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h
new file mode 100644
index 0000000000000000000000000000000000000000..15652353cd7e1f1e7d7a4c665703c0166682d790
--- /dev/null
+++ b/tensorflow/c/env.h
@@ -0,0 +1,194 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdbool.h>
+#include <stddef.h>
+
+#ifndef TENSORFLOW_C_ENV_H_
+#define TENSORFLOW_C_ENV_H_
+
+#include "tensorflow/c/c_api.h"
+
+// --------------------------------------------------------------------------
+// C API for tensorflow::Env.
+
+struct TF_WritableFileHandle;
+struct TF_StringStream;
+struct TF_Thread;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct TF_FileStatistics {
+ // The length of the file in bytes.
+ int64_t length;
+ // The last modified time in nanoseconds.
+ int64_t mtime_nsec;
+ // Whether the name refers to a directory.
+ bool is_directory;
+} TF_FileStatistics;
+
+// Options for threads created with TF_StartThread. Use TF_DefaultThreadOptions
+// to populate an instance with default values.
+typedef struct TF_ThreadOptions {
+ // Thread stack size to use (in bytes), zero implies that the system default
+ // will be used.
+ size_t stack_size;
+
+ // Guard area size to use near thread stacks (in bytes), zero implies that
+ // the system default will be used.
+ size_t guard_size;
+
+ // The NUMA node to use, -1 implies that there should be no NUMA affinity for
+ // this thread.
+ int numa_node;
+} TF_ThreadOptions;
+
+// Creates the specified directory. Typical status codes are:
+// * TF_OK - successfully created the directory
+// * TF_ALREADY_EXISTS - directory already exists
+// * TF_PERMISSION_DENIED - dirname is not writable
+TF_CAPI_EXPORT extern void TF_CreateDir(const char* dirname, TF_Status* status);
+
+// Deletes the specified directory. Typical status codes are:
+// * TF_OK - successfully deleted the directory
+// * TF_FAILED_PRECONDITION - the directory is not empty
+TF_CAPI_EXPORT extern void TF_DeleteDir(const char* dirname, TF_Status* status);
+
+// Deletes the specified directory and all subdirectories and files underneath
+// it. This is accomplished by traversing the directory tree rooted at dirname
+// and deleting entries as they are encountered.
+//
+// If dirname itself is not readable or does not exist, *undeleted_dir_count is
+// set to 1, *undeleted_file_count is set to 0 and an appropriate status (e.g.
+// TF_NOT_FOUND) is returned.
+//
+// If dirname and all its descendants were successfully deleted, TF_OK is
+// returned and both error counters are set to zero.
+//
+// Otherwise, while traversing the tree, undeleted_file_count and
+// undeleted_dir_count are updated if an entry of the corresponding type could
+// not be deleted. The returned error status represents the reason that any one
+// of these entries could not be deleted.
+//
+// Typical status codes:
+// * TF_OK - dirname exists and we were able to delete everything underneath
+// * TF_NOT_FOUND - dirname doesn't exist
+// * TF_PERMISSION_DENIED - dirname or some descendant is not writable
+// * TF_UNIMPLEMENTED - some underlying functions (like Delete) are not
+// implemented
+TF_CAPI_EXPORT extern void TF_DeleteRecursively(const char* dirname,
+ int64_t* undeleted_file_count,
+ int64_t* undeleted_dir_count,
+ TF_Status* status);
+
+// Obtains statistics for the given path. If status is TF_OK, *stats is
+// updated, otherwise it is not touched.
+TF_CAPI_EXPORT extern void TF_FileStat(const char* filename,
+ TF_FileStatistics* stats,
+ TF_Status* status);
+
+// Creates or truncates the given filename and returns a handle to be used for
+// appending data to the file. If status is TF_OK, *handle is updated and the
+// caller is responsible for freeing it (see TF_CloseWritableFile).
+TF_CAPI_EXPORT extern void TF_NewWritableFile(const char* filename,
+ TF_WritableFileHandle** handle,
+ TF_Status* status);
+
+// Closes the given handle and frees its memory. If there was a problem closing
+// the file, it is indicated by status. Memory is freed in any case.
+TF_CAPI_EXPORT extern void TF_CloseWritableFile(TF_WritableFileHandle* handle,
+ TF_Status* status);
+
+// Syncs content of the handle to the filesystem. Blocks waiting for the
+// filesystem to indicate that the content has been persisted.
+TF_CAPI_EXPORT extern void TF_SyncWritableFile(TF_WritableFileHandle* handle,
+ TF_Status* status);
+
+// Flush local buffers to the filesystem. If the process terminates after a
+// successful flush, the contents may still be persisted, since the underlying
+// filesystem may eventually flush the contents. If the OS or machine crashes
+// after a successful flush, the contents may or may not be persisted, depending
+// on the implementation.
+TF_CAPI_EXPORT extern void TF_FlushWritableFile(TF_WritableFileHandle* handle,
+ TF_Status* status);
+
+// Appends the given bytes to the file. Any failure to do so is indicated in
+// status.
+TF_CAPI_EXPORT extern void TF_AppendWritableFile(TF_WritableFileHandle* handle,
+ const char* data,
+ size_t length,
+ TF_Status* status);
+
+// Deletes the named file and indicates whether successful in *status.
+TF_CAPI_EXPORT extern void TF_DeleteFile(const char* filename,
+ TF_Status* status);
+
+// Retrieves the next item from the given TF_StringStream and places a pointer
+// to it in *result. If no more items are in the list, *result is set to NULL
+// and false is returned.
+//
+// Ownership of the items retrieved with this function remains with the library.
+// Item pointers are invalidated after a call to TF_StringStreamDone.
+TF_CAPI_EXPORT extern bool TF_StringStreamNext(TF_StringStream* list,
+ const char** result);
+
+// Frees the resources associated with given string list. All pointers returned
+// by TF_StringStreamNext are invalid after this call.
+TF_CAPI_EXPORT extern void TF_StringStreamDone(TF_StringStream* list);
+
+// Retrieves the list of children of the given directory. You can iterate
+// through the list with TF_StringStreamNext. The caller is responsible for
+// freeing the list (see TF_StringStreamDone).
+TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename,
+ TF_Status* status);
+
+// Retrieves a list of directory names on the local machine that may be used for
+// temporary storage. You can iterate through the list with TF_StringStreamNext.
+// The caller is responsible for freeing the list (see TF_StringStreamDone).
+TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void);
+
+// Returns the number of nanoseconds since the Unix epoch.
+TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void);
+
+// Returns the number of microseconds since the Unix epoch.
+TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void);
+
+// Returns the number of seconds since the Unix epoch.
+TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void);
+
+// Populates a TF_ThreadOptions struct with system-default values.
+TF_CAPI_EXPORT extern void TF_DefaultThreadOptions(TF_ThreadOptions* options);
+
+// Returns a new thread that is running work_func and is identified
+// (for debugging/performance-analysis) by thread_name.
+//
+// The given param (which may be null) is passed to work_func when the thread
+// starts. In this way, data may be passed from the thread back to the caller.
+//
+// Caller takes ownership of the result and must call TF_JoinThread on it
+// eventually.
+TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options,
+ const char* thread_name,
+ void (*work_func)(void*),
+ void* param);
+
+// Waits for the given thread to finish execution, then deletes it.
+TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // TENSORFLOW_C_ENV_H_
diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..687ad024137352662759ec1f43df87e89faca353
--- /dev/null
+++ b/tensorflow/c/env_test.cc
@@ -0,0 +1,127 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/env.h"
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x))
+
+// Exercises the directory and file wrappers end to end inside the first local
+// temp directory: create a directory, write a file into it, list and stat it,
+// then delete the file and the directory.
+TEST(TestEnv, TestDirHandling) {
+  TF_StringStream* tempdirs = TF_GetLocalTempDirectories();
+  const char* tempdir;
+  bool found = false;
+  // Only the first temp directory is exercised; the loop body ends in `break`.
+  while (TF_StringStreamNext(tempdirs, &tempdir)) {
+    found = true;
+
+    TF_Status* s = TF_NewStatus();
+
+    ::tensorflow::string dirpath =
+        ::tensorflow::io::JoinPath(tempdir, "somedir");
+    TF_CreateDir(dirpath.c_str(), s);
+    ASSERT_TF_OK(s) << "TF_CreateDir failed for " << dirpath << ": "
+                    << TF_Message(s);
+
+    ::tensorflow::string filepath =
+        ::tensorflow::io::JoinPath(dirpath, "somefile.txt");
+    TF_WritableFileHandle* handle;
+    TF_NewWritableFile(filepath.c_str(), &handle, s);
+    ASSERT_TF_OK(s) << "NewWritableFile failed for " << filepath << ": "
+                    << TF_Message(s);
+
+    const char* data = "Hello, world!\n";
+    TF_AppendWritableFile(handle, data, strlen(data), s);
+    ASSERT_TF_OK(s) << "TF_AppendWritableFile failed to append data to file at "
+                    << filepath << ": " << TF_Message(s);
+
+    TF_CloseWritableFile(handle, s);
+    ASSERT_TF_OK(s) << "TF_CloseWritableFile failed to close handle to "
+                    << filepath << ": " << TF_Message(s);
+
+    TF_StringStream* children = TF_GetChildren(dirpath.c_str(), s);
+    ASSERT_TF_OK(s) << "TF_GetChildren failed for " << dirpath;
+    const char* childpath;
+    ASSERT_TRUE(TF_StringStreamNext(children, &childpath));
+    ASSERT_EQ(::tensorflow::string(childpath), "somefile.txt");
+    // There should only be one file in this directory.
+    ASSERT_FALSE(TF_StringStreamNext(children, &childpath));
+    ASSERT_EQ(childpath, nullptr);
+    TF_StringStreamDone(children);
+
+    TF_FileStatistics stats;
+    TF_FileStat(filepath.c_str(), &stats, s);
+    // `stats` is only written on success, so check the status before reading
+    // its fields.
+    ASSERT_TF_OK(s) << "TF_FileStat failed for " << filepath << ": "
+                    << TF_Message(s);
+    ASSERT_EQ(stats.length, strlen(data));
+    ASSERT_FALSE(stats.is_directory);
+    ASSERT_GT(stats.mtime_nsec, 0);
+
+    // Trying to delete a non-empty directory should fail.
+    TF_DeleteDir(dirpath.c_str(), s);
+    ASSERT_NE(TF_OK, TF_GetCode(s))
+        << "TF_DeleteDir unexpectedly succeeded with a non-empty directory "
+        << dirpath;
+
+    TF_DeleteFile(filepath.c_str(), s);
+    ASSERT_TF_OK(s) << "TF_DeleteFile failed for " << filepath << ": "
+                    << TF_Message(s);
+
+    // Now deleting the directory should work.
+    TF_DeleteDir(dirpath.c_str(), s);
+    ASSERT_TF_OK(s) << "TF_DeleteDir failed for " << dirpath << ": "
+                    << TF_Message(s);
+
+    TF_DeleteStatus(s);
+    break;
+  }
+
+  ASSERT_TRUE(found) << "expected at least one temp dir";
+
+  TF_StringStreamDone(tempdirs);
+}
+
+// Sanity-checks the three clock wrappers against a fixed lower bound.
+TEST(TestEnv, TestTimeFunctions) {
+  // Midnight Jan 1, 2000, in seconds since the Unix epoch.
+  const int y2k_seconds = 946684800;
+  ASSERT_GE(TF_NowSeconds(), y2k_seconds);
+  ASSERT_GE(TF_NowMicros(), y2k_seconds * 1e6);
+  ASSERT_GE(TF_NowNanos(), y2k_seconds * 1e9);
+}
+
+namespace {
+
+// State shared between TestThreads and the thread it starts.
+struct SomeThreadData {
+ // Guards did_work.
+ ::tensorflow::mutex mu;
+ // Set to true by SomeThreadFunc.
+ bool did_work = false;
+};
+
+// Thread body used by TestThreads: `data` points at a SomeThreadData whose
+// did_work flag is set while holding its mutex.
+void SomeThreadFunc(void* data) {
+  auto* real_data = static_cast<SomeThreadData*>(data);
+  ::tensorflow::mutex_lock l(real_data->mu);
+  real_data->did_work = true;
+}
+
+} // namespace
+
+// Starts a thread with default options, joins it, and checks that the thread
+// function ran.
+TEST(TestEnv, TestThreads) {
+  SomeThreadData shared;
+  TF_ThreadOptions thread_options;
+  TF_DefaultThreadOptions(&thread_options);
+
+  TF_Thread* worker = TF_StartThread(&thread_options, "SomeThreadName",
+                                     &SomeThreadFunc, &shared);
+  TF_JoinThread(worker);
+
+  ::tensorflow::mutex_lock lock(shared.mu);
+  ASSERT_TRUE(shared.did_work);
+}
diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc
index ca69345264607ac689fb556b4f5c9bc08ea5eb88..2a4eaecb6cf2740a522b1e849d1306ebde6c4577 100644
--- a/tensorflow/c/kernels.cc
+++ b/tensorflow/c/kernels.cc
@@ -15,7 +15,9 @@ limitations under the License.
#include
+#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/kernels.h"
+#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -116,3 +118,43 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
TF_SetStatus(status, TF_OK, "");
}
+
+// Returns num_inputs() of the OpKernelContext behind `ctx`.
+int TF_NumInputs(TF_OpKernelContext* ctx) {
+  return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->num_inputs();
+}
+
+// Returns num_outputs() of the OpKernelContext behind `ctx`.
+int TF_NumOutputs(TF_OpKernelContext* ctx) {
+  return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->num_outputs();
+}
+
+// Converts the i-th input of `ctx` to a TF_Tensor. *tensor is written only
+// when `status` ends up TF_OK; an index outside [0, num_inputs) yields
+// TF_OUT_OF_RANGE.
+void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
+                 TF_Status* status) {
+  auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
+  const bool in_range = i >= 0 && i < context->num_inputs();
+  if (!in_range) {
+    TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
+    return;
+  }
+  const ::tensorflow::Tensor& source = context->input(i);
+  TF_Tensor* converted = ::tensorflow::TF_TensorFromTensor(source, status);
+  if (TF_GetCode(status) != TF_OK) {
+    return;
+  }
+  *tensor = converted;
+}
+
+// Sets the i-th output of `ctx` to `tensor`. An index outside
+// [0, num_outputs) yields TF_OUT_OF_RANGE; a tensor conversion failure is
+// reported in `status`. In both cases the context is left unmodified.
+void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
+                  TF_Status* status) {
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
+  // The index addresses an output slot, so it is validated against the
+  // output count rather than the input count.
+  if (i < 0 || i >= cc_ctx->num_outputs()) {
+    TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
+    return;
+  }
+  ::tensorflow::Tensor cc_tensor;
+  ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::Set_TF_Status_from_Status(status, s);
+  if (s.ok()) {
+    cc_ctx->set_output(i, cc_tensor);
+  }
+}
diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h
index 2518789a3c141755d0b3373d53642c487331f68b..1a91aa184f11ac8e45b38a1d106c7b445747a7c1 100644
--- a/tensorflow/c/kernels.h
+++ b/tensorflow/c/kernels.h
@@ -85,6 +85,32 @@ TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name,
// builder is not registered with TensorFlow via TF_RegisterKernelBuilder.
TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
+// --------------------------------------------------------------------------
+// OpKernelContext routines
+
+// TF_NumInputs returns the number of inputs available in ctx.
+TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
+
+// TF_NumOutputs returns the number of outputs to be placed in *ctx by the
+// kernel.
+TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx);
+
+// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is
+// populated and its ownership is passed to the caller. In any other case,
+// *tensor is not modified.
+//
+// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE.
+TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i,
+ TF_Tensor** tensor, TF_Status* status);
+
+// Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but
+// TF_OK, ctx is left unmodified.
+//
+// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE.
+TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i,
+ const TF_Tensor* tensor,
+ TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc
index e706c7c1d96ee1781d8efc0f28c5e0cbcbc80861..e659ee3c3d258a626ccf03a782ec031b5a703a48 100644
--- a/tensorflow/c/kernels_test.cc
+++ b/tensorflow/c/kernels_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/kernels.h"
+#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb_text.h"
#include "tensorflow/core/framework/op.h"
@@ -31,7 +32,6 @@ struct MyCustomKernel {
static bool delete_called = false;
static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
- LOG(INFO) << "Wow, actually got into creation";
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
@@ -51,12 +51,31 @@ static void MyDeleteFunc(void* kernel) {
delete s;
}
+namespace tensorflow {
+
+// Builds a NodeDef for `op_name` on `device_name` with two inputs ("input1",
+// "input2") and passes it to CreateOpKernel. Returns whatever CreateOpKernel
+// returns; its outcome is written to *status.
+static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
+                                               const char* op_name,
+                                               Status* status) {
+  NodeDef def;
+  def.set_op(op_name);
+  def.set_device(device_name);
+  def.add_input("input1");
+  def.add_input("input2");
+  return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1,
+                        status);
+}
+
// Tests registration of a single C kernel and checks that calls through the
// C/C++ boundary are being made.
TEST(TestKernel, TestRegisterKernelBuilder) {
const char* kernel_name = "SomeKernelName";
const char* op_name = "FooOp";
- const char* device_name = "barDev";
+ const char* device_name = "FakeDeviceName1";
+
+ REGISTER_OP(op_name)
+ .Input("input1: double")
+ .Input("input2: uint8")
+ .Output("output1: uint8");
TF_KernelBuilder* builder = TF_NewKernelBuilder(
op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);
@@ -65,35 +84,120 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(kernel_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
- TF_Buffer* buf = TF_GetRegisteredKernelsForOp("FooOp", status);
+ TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
- ::tensorflow::KernelList list;
+ KernelList list;
list.ParseFromArray(buf->data, buf->length);
ASSERT_EQ(1, list.kernel_size());
- ASSERT_EQ("barDev", list.kernel(0).device_type());
+ ASSERT_EQ(device_name, list.kernel(0).device_type());
TF_DeleteBuffer(buf);
TF_DeleteStatus(status);
}
- REGISTER_OP("FooOp")
+  {
+    Status status;
+    // Instantiate the kernel for this op/device pair and run it once.
+    std::unique_ptr<OpKernel> kernel =
+        GetFakeKernel(device_name, op_name, &status);
+    TF_EXPECT_OK(status);
+    ASSERT_NE(nullptr, kernel.get());
+    kernel->Compute(nullptr);
+  }
+
+ ASSERT_TRUE(delete_called);
+}
+
+// Minimal DeviceBase used to populate OpKernelContext::Params::device in
+// TestInputAndOutputCount.
+class DummyDevice : public DeviceBase {
+ public:
+ DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
+ // Reports the `save` flag passed at construction.
+ bool RequiresRecordingAccessedTensors() const override { return save_; }
+ // Always hands back cpu_allocator(), ignoring the requested attributes.
+ Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
+ return cpu_allocator();
+ }
+
+ private:
+ bool save_;
+};
+
+TEST(TestKernel, TestInputAndOutputCount) {
+ const char* kernel_name = "InputOutputCounterKernel";
+ const char* op_name = "BarOp";
+ const char* device_name = "FakeDeviceName2";
+
+ REGISTER_OP(op_name)
.Input("input1: double")
.Input("input2: uint8")
.Output("output1: uint8");
+ static int num_inputs = 0;
+ static int num_outputs = 0;
+
+ // A kernel whose Compute function has a side-effect of updating num_inputs
+ // and num_outputs. Various functions on TF_OpKernelContext are also
+ // exercised.
+  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
+    num_inputs = TF_NumInputs(ctx);
+    num_outputs = TF_NumOutputs(ctx);
+
+    TF_Tensor* input = nullptr;
+    TF_Status* s = TF_NewStatus();
+    TF_GetInput(ctx, 0, &input, s);
+    EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s);
+    // The test feeds a uint8 scalar holding 123 as input 0.
+    EXPECT_EQ(123, *static_cast<tensorflow::uint8*>(TF_TensorData(input)));
+    // Out-of-range indices must be rejected.
+    TF_GetInput(ctx, -1, &input, s);
+    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
+    TF_GetInput(ctx, 3, &input, s);
+    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
+
+    // Copy the input tensor to output.
+    TF_SetOutput(ctx, 0, input, s);
+    EXPECT_EQ(TF_OK, TF_GetCode(s));
+
+    TF_SetOutput(ctx, 24, input, s);
+    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
+
+    TF_DeleteStatus(s);
+    if (input != nullptr) {
+      TF_DeleteTensor(input);
+    }
+  };
+
+ TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
+ my_compute_func, nullptr);
+
+ {
+ TF_Status* status = TF_NewStatus();
+ TF_RegisterKernelBuilder(kernel_name, builder, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status));
+ TF_DeleteStatus(status);
+ }
+
{
- ::tensorflow::NodeDef def;
- def.set_op("FooOp");
- def.set_device("bar");
- def.add_input("input1");
- def.add_input("input2");
- ::tensorflow::Status status;
- std::unique_ptr<::tensorflow::OpKernel> kernel =
- ::tensorflow::CreateOpKernel(::tensorflow::DeviceType("barDev"),
- nullptr, nullptr, def, 1, &status);
+    OpKernelContext::Params p;
+    DummyDevice dummy_device(nullptr, false);
+    p.device = &dummy_device;
+
+    Tensor t(tensorflow::uint8(123));
+
+    gtl::InlinedVector<TensorValue, 4> inputs;
+    // Simulate 2 inputs
+    inputs.emplace_back(&t);
+    inputs.emplace_back();
+    p.inputs = &inputs;
+
+    Status status;
+    std::unique_ptr<OpKernel> kernel =
+        GetFakeKernel(device_name, op_name, &status);
TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get());
- kernel->Compute(nullptr);
- }
- ASSERT_TRUE(delete_called);
+    p.op_kernel = kernel.get();
+    OpKernelContext ctx(&p);
+    kernel->Compute(&ctx);
+
+    ASSERT_EQ(2, num_inputs);
+    ASSERT_EQ(1, num_outputs);
+    // The compute function copies input 0 (uint8 value 123) to output 0.
+    ASSERT_EQ(123, ctx.mutable_output(0)->scalar<uint8>()());
+ }
}
+
+} // namespace tensorflow
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 247236b760dd8c07bbb08426100b6a4d34296d2e..98d8393332269ae349cf8aa5c0b612c6f17172e6 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -160,4 +160,17 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}
+// Under the graph's mutex, forwards to Graph::AddWhileInputHack to add
+// `new_src` as an input of `dst`, storing the result in `status`. On success
+// a mutation is recorded for `dst`.
+void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
+                       TF_Status* status) {
+  mutex_lock l(graph->mu);
+  auto* src_node = &new_src.oper->node;
+  auto* dst_node = &dst->node;
+  status->status =
+      graph->graph.AddWhileInputHack(src_node, new_src.index, dst_node);
+  if (!status->status.ok()) {
+    return;
+  }
+  // This modification only updates the destination node for
+  // the purposes of running this graph in a session. Thus, we don't
+  // record the source node as being modified.
+  RecordMutation(graph, *dst, "adding input tensor");
+}
+
} // namespace tensorflow
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 5cce84020bc68d912d259f51512341eb5f464a2c..44779ca656165dd65590cb5e9ea3ccf71165ed63 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -34,6 +34,7 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
+// Updates 'dst' to consume 'new_src'.
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status);
@@ -65,6 +66,13 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
// because I couldn't get SWIG to work otherwise.
void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
size_t proto_len, TF_Status* status);
+
+// This method is used to add a new input edge to 'dst', which must be a While
+// op. The While op's "T" attribute must have already been updated to include
+// the new edge. This is used to construct tf.while_loop gradients.
+void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
+ TF_Status* status);
+
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py
index 7df80ec01245a7fe820c79d5879458c4cd0a93cb..d58acde09f007bc9df40b08b0ef79c6031ca7941 100644
--- a/tensorflow/compat_template_v1.__init__.py
+++ b/tensorflow/compat_template_v1.__init__.py
@@ -23,12 +23,12 @@ import os as _os
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
+# API IMPORTS PLACEHOLDER
+
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow_estimator.python.estimator.api.estimator'))
-# API IMPORTS PLACEHOLDER
-
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index e0ac7130a64d3928c39440c0e10a2d2e1990b9cd..ab1c1be344e2257721507543bc7647d4ff4becb2 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -178,7 +178,7 @@ Status GenArgMethods(const tf2xla::Config& config,
TF_RETURN_IF_ERROR(
AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
const string code = R"(
- void set_arg{{NAME}}_data(void* data) {
+ void set_arg{{NAME}}_data(const void* data) {
set_arg_data({{I}}, data);
}
{{TYPE}}* arg{{NAME}}_data() {
diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden
index a2cdab5d1a8e72504ca11b789287d4efd07a59e9..968afad65ed6d4b5510687df484b7ce6743f6a85 100644
--- a/tensorflow/compiler/aot/codegen_test_h.golden
+++ b/tensorflow/compiler/aot/codegen_test_h.golden
@@ -114,7 +114,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
- void set_arg0_data(void* data) {
+ void set_arg0_data(const void* data) {
set_arg_data(0, data);
}
float* arg0_data() {
@@ -132,7 +132,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
arg_data(0)))[dim0][dim1];
}
- void set_arg_myfeed_data(void* data) {
+ void set_arg_myfeed_data(const void* data) {
set_arg_data(0, data);
}
float* arg_myfeed_data() {
@@ -150,7 +150,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
arg_data(0)))[dim0][dim1];
}
- void set_arg1_data(void* data) {
+ void set_arg1_data(const void* data) {
set_arg_data(1, data);
}
tensorflow::int64* arg1_data() {
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 2dc3e8c9113b37bf9d575ad66783f4ab49478af4..4051664c24cacad4a2d151ad3ac9009015900609 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -283,7 +283,7 @@ def tf_library(
)
# Variables used for gen_test and gen_benchmark.
- cpp_class_split = cpp_class.rsplit("::", maxsplit = 2)
+ cpp_class_split = cpp_class.rsplit("::", 2)
if len(cpp_class_split) == 1:
no_ns_name = cpp_class_split[0]
else:
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index be91ed4f432b1890c22900f293fd4196e5c9d970..d8c88a9fca2db74265b4962e07a66ab214b1d994 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -76,6 +76,7 @@ cc_library(
srcs = ["xla_cpu_device.cc"],
visibility = [":friends"],
deps = [
+ ":create_xla_launch_op", # buildcleaner: keep
":flags",
":jit_compilation_passes",
":xla_device",
@@ -95,6 +96,7 @@ cc_library(
srcs = ["xla_gpu_device.cc"],
visibility = [":friends"],
deps = [
+ ":create_xla_launch_op", # buildcleaner: keep
":jit_compilation_passes",
":xla_device",
"//tensorflow/compiler/jit/kernels:xla_ops",
@@ -104,6 +106,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
alwayslink = 1,
)
@@ -512,6 +515,7 @@ cc_library(
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
@@ -610,6 +614,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
+ "//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:scope",
@@ -622,6 +627,7 @@ tf_cc_test(
"//tensorflow/compiler/tf2xla/cc:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/compiler/xla:test",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index f478832781cb1dc045d9163d4a6f5e5f64a8a705..03aba97bbe81a11f6366d118ee5bc573d0c6b31b 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -779,7 +779,8 @@ Status Encapsulator::Subgraph::RecordArg(
if (inserted) {
NodeDef arg_def;
NodeDefBuilder builder(
- absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp);
+ absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp,
+ NodeDebugInfo(src_node->def()));
DataType dtype = edge->dst()->input_type(edge->dst_input());
builder.Attr("T", dtype);
builder.Attr("index", arg_index);
@@ -814,7 +815,8 @@ Status Encapsulator::Subgraph::RecordResult(
if (inserted) {
NodeDef ret_def;
NodeDefBuilder builder(
- absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp);
+ absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp,
+ NodeDebugInfo(src_node->def()));
DataType dtype = src_node->output_type(src_slot);
builder.Attr("T", dtype);
builder.Attr("index", ret_index);
@@ -974,6 +976,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
}
NodeDef host_compute_def;
+ // TODO(shikharagarwal): What source node should we use for errors?
NodeDefBuilder builder(absl::StrCat("outside_compilation_",
oc_subgraph_name, "_host_compute"),
kHostComputeOp);
@@ -1040,6 +1043,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
Graph* graph_out) {
if (sequencer_ == nullptr) {
NodeDef seq_def;
+ // TODO(shikharagarwal): What source node should we use for errors?
NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
builder.Device(device_);
@@ -1214,7 +1218,8 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder(
GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
NodeDef key_def;
NodeDefBuilder builder(
- absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder");
+ absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder",
+ NodeDebugInfo(call_node_def_));
builder.Attr("dtype", DT_STRING);
builder.Attr("shape", shape_proto);
builder.Attr("_host_compute_call_node", call_node_def_.name());
@@ -1248,6 +1253,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
}
NodeDef recv_def;
+ // TODO(shikharagarwal): What source node should we use for errors?
NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_recv"),
kRecvAtHostOp);
@@ -1303,6 +1309,7 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
}
NodeDef send_def;
+ // TODO(shikharagarwal): What source node should we use for errors?
NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_send"),
kSendFromHostOp);
@@ -1833,8 +1840,9 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port,
// Add any Enter nodes required to bring the constant to the correct control
// flow frame.
while (!control_flow_info[src_node->id()].frame_name.empty()) {
+ NodeDebugInfo debug_info(*src_node);
NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
- options.op_registry());
+ options.op_registry(), &debug_info);
enter_builder.Attr("frame_name",
control_flow_info[src_node->id()].frame_name);
enter_builder.Attr("is_constant", true);
@@ -2018,7 +2026,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
return errors::InvalidArgument(
"Shape inference is not possible for outside_compilation "
"SendFromHost node ",
- send_node->name(), " because shape of node ", n->name(),
+ send_node->name(), " because shape of node ",
+ FormatNodeForError(*n),
" will not be known at compilation time.");
}
}
@@ -2047,8 +2056,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
return errors::Internal(
"Internal assumption failed while rewriting an outside_compilation "
"cluster that contains a while loop. Logic assumes back-edge is to "
- "port 1 of a 2-input "
- "Merge node.");
+ "port 1 of a 2-input Merge node.");
}
// Connect the existing edge to both inputs of the Merge node so that the
// graph will be well-formed.
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index de89be9a3555960dabe7bacd17226c15ae888ae6..8617beec004d0fe912155f054442c5b6249bb6b5 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -299,7 +299,7 @@ REGISTER_OP("XlaHostCompute")
.Attr("Toutputs: list(type) >= 0")
.Attr("ancestors: list(string) >= 0")
.Attr("key: string")
- .Attr("shape_inference_graph: string = ''")
+ .Attr("shape_inference_graph: func")
.Attr("shapes: list(shape) >= 0")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);
@@ -510,11 +510,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
s = ConvertGraphDefToGraph(options, *graphdef, graph.get());
if (!s.ok()) return s;
- s = PerformStaticShapeInferenceBeforeEncapsulation(
- graph.get(), "_encapsulate", "_outside");
- if (!s.ok()) return s;
-
- s = PreprocessForEncapsulation(graph.get(), "_encapsulate", "_outside");
+ s = PerformStaticShapeInferenceBeforeEncapsulation(graph.get());
if (!s.ok()) return s;
std::unique_ptr graph_out;
@@ -550,6 +546,14 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
graphdef->Swap(&graphdef_out);
*library = lib_def->ToProto();
+ // Remove "_xla_inferred_shapes" attrs. They are added by
+ // `PerformStaticShapeInferenceBeforeEncapsulation`.
+ for (FunctionDef& fdef : *library->mutable_function()) {
+ for (NodeDef& node_def : *fdef.mutable_node_def()) {
+ node_def.mutable_attr()->erase("_xla_inferred_shapes");
+ }
+ }
+
return s;
}
@@ -901,18 +905,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape.opts());
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, shape.opts());
+ Node* recv = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
}
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
@@ -931,8 +939,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}},
{"c"}},
@@ -948,16 +955,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, b2.opts());
+ Node* recv = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
b2.opts()
.WithName("E")
- .WithControlInputs({recv, b})
+ .WithControlInputs({recv})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
- b2.opts().WithControlInput(e));
+ b2.opts().WithControlInput(e).WithAttr(
+ kXlaHasHostTransferAttrName, true));
Node* s = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
@@ -966,9 +975,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
Node* call =
- b2.opts().WithControlInputs({s}).FinalizeBuilder(&node_builder);
+ b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder);
- Binary(a, call, b2.opts().WithName("G").WithControlInputs({e}));
+ Binary(a, call, b2.opts().WithName("G").WithControlInputs({call}));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1022,14 +1031,16 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, shape1.opts());
+ Node* recv = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
}
@@ -1037,33 +1048,45 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape2.opts());
- Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, shape2.opts());
+ Node* recv1 = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
shape2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
- {DT_FLOAT, DT_FLOAT}, shape2.opts());
+ Node* recv2 = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT},
+ shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* g = Binary(e, ops::NodeOut(recv2, 0),
+ shape2.opts()
+ .WithName("G")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O2"));
Node* h = Binary(ops::NodeOut(recv2, 1), e,
shape2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h},
+ shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected));
}
+ NameAttrList shape_inference_graph1, shape_inference_graph2;
+ shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1");
+ shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
{{"I"},
"UnaryTest",
- {"outside_compilation_O2_host_compute:outputs:0"}},
+ {"outside_compilation_O2_host_compute:outputs:1"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
@@ -1073,11 +1096,10 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
"XlaHostCompute",
{"F:o:0", "D:o:0"},
{{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", absl::Span({DT_FLOAT})},
+ {"Toutputs", absl::Span({DT_FLOAT, DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O2"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O2"},
+ {"shape_inference_graph", shape_inference_graph2},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O2"}},
{"F"}},
@@ -1088,13 +1110,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph1},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
- {{"i_0_retval_retval", "I:o:0"}});
+ {{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"},
+ {"i_0_retval_retval", "I:o:0"}});
{
std::unique_ptr lib_def(
@@ -1105,19 +1127,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, b2.opts());
+ Node* recv1 = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts()
.WithName("E")
- .WithControlInputs({recv1, b})
+ .WithControlInputs({recv1})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
- b2.opts().WithControlInput(e));
+ b2.opts().WithControlInput(e).WithAttr(
+ kXlaHasHostTransferAttrName, true));
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
- {DT_FLOAT, DT_FLOAT}, b2.opts());
+ Node* recv2 = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* g = Binary(e, ops::NodeOut(recv2, 0),
b2.opts()
.WithName("G")
@@ -1130,7 +1155,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* send2 =
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* s = Sequencer(b2.opts()
.WithName("F1_sequencer")
@@ -1139,12 +1165,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
- Node* call = b2.opts().WithControlInput(s).FinalizeBuilder(&node_builder);
+ Node* call =
+ b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder);
- Binary(g, call, b2.opts().WithName("J"));
+ Binary(ops::NodeOut(call, 0), ops::NodeOut(call, 1),
+ b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
-
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
@@ -1196,7 +1223,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
- {"f_0_retval_retval:float", "d_0_retval_retval:float"}, {},
+ {"e_0_retval_retval:float", "f_0_retval_retval:float",
+ "d_0_retval_retval:float"},
+ {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1212,35 +1241,37 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
- {{"d_0_retval_retval", "D:o:0"}, {"f_0_retval_retval", "F:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"d_0_retval_retval", "D:o:0"},
+ {"f_0_retval_retval", "F:o:0"}});
*library_expected.add_function() = FunctionDefHelper::Create(
- "F2", {"f_0_arg:float", "bridge_e_g_0_arg:float"},
- {"i_0_retval_retval:float", "g_0_retval_retval:float"}, {},
+ "F2", {"e_0_arg:float", "f_0_arg:float", "d_0_arg:float"},
+ {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {},
{
- {{"G"}, "BinaryTest", {"bridge_e_g_0_arg", "f_0_arg"}},
+ {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
{{"I"},
"BinaryTest",
{"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
- {"G:o:0"},
- {{"Tinputs", absl::Span({DT_FLOAT})},
+ {"d_0_arg", "G:o:0"},
+ {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})},
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F2_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"i_0_retval_retval", "I:o:0"}, {"g_0_retval_retval", "G:o:0"}});
+ {{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}});
{
std::unique_ptr lib_def(
@@ -1251,16 +1282,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
Node* key_constant1 =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, b2.opts());
+ Node* recv1 = RecvAtHost(
+ ops::NodeOut(key_constant1, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts()
.WithName("E")
- .WithControlInputs({recv1, b})
+ .WithControlInputs({recv1})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e},
- b2.opts().WithControlInput(e));
+ b2.opts().WithControlInput(e).WithAttr(
+ kXlaHasHostTransferAttrName, true));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
"F1");
@@ -1268,29 +1301,33 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
- b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
+ b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1);
Node* key_constant2 =
KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1",
- {DT_FLOAT}, b2.opts());
- Node* h = Binary(ops::NodeOut(call1, 1), recv2,
+ Node* recv2 = RecvAtHost(
+ ops::NodeOut(key_constant2, 0), "F2", "O1", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* h = Binary(recv2, ops::NodeOut(recv2, 1),
b2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
.WithAttr("_outside", "O1"));
- Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h},
- b2.opts());
+ Node* send2 =
+ SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* s2 = Sequencer(
b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
"F2");
NodeBuilder node_builder2("F2", "F2", lib_def.get());
- node_builder2.Input(call1).Input(e);
+ node_builder2.Input(call1)
+ .Input(ops::NodeOut(call1, 1))
+ .Input(ops::NodeOut(call1, 2));
Node* call2 = b2.opts()
- .WithControlInputs({s2, e, call1})
+ .WithControlInputs({s2, call1})
.FinalizeBuilder(&node_builder2);
- Binary(ops::NodeOut(call2, 1), call2, b2.opts().WithName("J"));
+ Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1326,8 +1363,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
Node* h = Unary(g, b1.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
- .WithAttr("_outside", "O1")
- .WithControlInput(e));
+ .WithAttr("_outside", "O1"));
Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
Binary(f, i, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
@@ -1358,7 +1394,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}},
@@ -1380,7 +1416,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F2_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
@@ -1401,7 +1437,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts()
.WithName("E")
- .WithControlInputs({recv1, b})
+ .WithControlInputs({recv1})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e},
@@ -1413,7 +1449,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
- b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
+ b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1);
Node* key_constant2 =
KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
@@ -1422,8 +1458,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
Node* h = Unary(recv2, b2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
- .WithAttr("_outside", "O1")
- .WithControlInput(e));
+ .WithAttr("_outside", "O1"));
Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h},
b2.opts());
@@ -1484,12 +1519,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
{"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
- {},
- {{"Tinputs", absl::Span({})},
+ {"a_0_arg"},
+ {{"Tinputs", absl::Span({DT_FLOAT})},
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
@@ -1503,16 +1538,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
- Node* e = Unary(a, b2.opts()
- .WithName("E")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
+ Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT}, b2.opts());
+ Node* e = Unary(recv1, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
Node* send1 =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
Node* s1 = Sequencer(
- b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1");
+ b2.opts().WithName("F1_sequencer").WithControlInputs({send1, recv1}),
+ "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
@@ -1569,12 +1607,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
{"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
- {},
- {{"Tinputs", absl::Span({})},
+ {"a_0_arg"},
+ {{"Tinputs", absl::Span({DT_FLOAT})},
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}},
@@ -1591,13 +1629,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 =
- RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts());
- Node* e = Unary(a, b2.opts()
- .WithName("E")
- .WithControlInput(recv1)
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
+ Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT}, b2.opts());
+ Node* e = Unary(recv1, b2.opts()
+ .WithName("E")
+ .WithControlInput(recv1)
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
Node* send1 =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
Node* s1 = Sequencer(
@@ -1644,8 +1682,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ {
+ GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
+ Node* key_constant = KeyPlaceholder("F1", shape1.opts());
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
+ }
+
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1654,14 +1711,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", absl::Span({DT_FLOAT})},
- {"Toutputs", absl::Span({})},
+ {"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"f_0_retval_retval", "F:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -1678,14 +1736,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
+ Node* send1 =
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
Node* s1 = Sequencer(
- b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1");
+ b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
+ "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("G"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1722,8 +1783,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ {
+ GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
+ Node* key_constant = KeyPlaceholder("F1", shape1.opts());
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
+ }
+
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1736,14 +1816,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", absl::Span({DT_FLOAT})},
- {"Toutputs", absl::Span({})},
+ {"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"f_0_retval_retval", "F:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -1760,7 +1841,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {},
+ Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
b2.opts().WithControlInput(e));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
@@ -1770,7 +1851,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
Node* call1 =
b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("G"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1813,22 +1894,45 @@ TEST(EncapsulateSubgraphsTest,
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ {
+ GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
+ Node* key_constant = KeyPlaceholder("F1", shape1.opts());
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
+ }
+
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape2.opts());
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
- {DT_FLOAT}, shape2.opts());
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT},
+ shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, shape2.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g},
+ shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected));
}
+ NameAttrList shape_inference_graph1;
+ shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1");
+ NameAttrList shape_inference_graph2;
+ shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1836,6 +1940,16 @@ TEST(EncapsulateSubgraphsTest,
{{"H"},
"UnaryTest",
{"outside_compilation_O2_host_compute:outputs:0"}},
+ {{"outside_compilation_O1_host_compute"},
+ "XlaHostCompute",
+ {"a_0_arg"},
+ {{"Tinputs", absl::Span({DT_FLOAT})},
+ {"Toutputs", absl::Span({DT_FLOAT})},
+ {"ancestors", absl::Span({})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph1},
+ {"shapes", absl::Span({})},
+ {"_outside_compilation_subgraph", "O1"}}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"F:o:0"},
@@ -1843,12 +1957,12 @@ TEST(EncapsulateSubgraphsTest,
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O2"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O2"},
+ {"shape_inference_graph", shape_inference_graph2},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O2"}}},
},
- {{"h_0_retval_retval", "H:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"h_0_retval_retval", "H:o:0"}});
{
std::unique_ptr lib_def(
@@ -1856,30 +1970,39 @@ TEST(EncapsulateSubgraphsTest,
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
-
- Node* e = Unary(a, b2.opts()
- .WithName("E")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
- {DT_FLOAT}, b2.opts());
- Node* g = Unary(recv, b2.opts()
- .WithName("G")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O2")
- .WithControlInput(e));
- Node* send =
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, b2.opts());
- Node* s1 = Sequencer(
- b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
- "F1");
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+
+ Node* e = Unary(recv1, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* send1 =
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* g = Unary(recv2, b2.opts()
+ .WithName("G")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O2")
+ .WithControlInput(e));
+ Node* send2 =
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* s1 = Sequencer(b2.opts()
+ .WithName("F1_sequencer")
+ .WithControlInputs({recv1, send1, recv2, send2}),
+ "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("I"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1925,19 +2048,24 @@ TEST(EncapsulateSubgraphsTest,
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, shape1.opts());
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
}
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1945,6 +2073,16 @@ TEST(EncapsulateSubgraphsTest,
"UnaryTest",
{"outside_compilation_O1_host_compute:outputs:0"}},
{{"H"}, "UnaryTest", {"F:o:0"}},
+ {{"outside_compilation_O2_host_compute"},
+ "XlaHostCompute",
+ {"a_0_arg"},
+ {{"Tinputs", absl::Span({DT_FLOAT})},
+ {"Toutputs", absl::Span({})},
+ {"ancestors", absl::Span({})},
+ {"key", "host_compute_channel_F1_O2"},
+ {"shape_inference_graph", NameAttrList()},
+ {"shapes", absl::Span({})},
+ {"_outside_compilation_subgraph", "O2"}}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
@@ -1952,12 +2090,12 @@ TEST(EncapsulateSubgraphsTest,
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"h_0_retval_retval", "H:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"h_0_retval_retval", "H:o:0"}});
{
std::unique_ptr lib_def(
@@ -1968,27 +2106,33 @@ TEST(EncapsulateSubgraphsTest,
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, b2.opts());
- Node* e = Unary(recv, b2.opts()
- .WithName("E")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(recv1, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
Node* send =
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
- /*Node* g =*/Unary(a, b2.opts()
- .WithName("G")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O2")
- .WithControlInput(e));
- Node* s1 = Sequencer(
- b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
- "F1");
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ /*Node* g =*/Unary(recv2, b2.opts()
+ .WithName("G")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O2")
+ .WithControlInput(e));
+ Node* s1 = Sequencer(b2.opts()
+ .WithName("F1_sequencer")
+ .WithControlInputs({recv1, recv2, send}),
+ "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("I"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -2039,19 +2183,24 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, shape1.opts());
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
}
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
{{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}},
@@ -2063,8 +2212,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}}},
{{"outside_compilation_O2_host_compute"},
@@ -2074,7 +2222,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"Toutputs", absl::Span({})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O2"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O2"}},
{}},
@@ -2085,11 +2233,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"Toutputs", absl::Span({})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O3"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O3"}},
{}}},
- {{"h_0_retval_retval", "H:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"h_0_retval_retval", "H:o:0"}});
{
std::unique_ptr lib_def(
@@ -2100,23 +2249,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, b2.opts());
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(recv1, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send =
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
- {DT_FLOAT}, b2.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* g = Unary(recv2, b2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
- Node* recv3 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3",
- {DT_FLOAT}, b2.opts());
+ Node* recv3 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
/*Node* i =*/Binary(recv3, e,
b2.opts()
.WithName("I")
@@ -2131,7 +2284,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("J"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -2167,14 +2320,44 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ {
+ GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
+ Node* key_constant = KeyPlaceholder("F1", shape1.opts());
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
+ }
+
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
+ {{"outside_compilation_O1_host_compute"},
+ "XlaHostCompute",
+ {"a_0_arg"},
+ {{"Tinputs", absl::Span({DT_FLOAT})},
+ {"Toutputs", absl::Span({DT_FLOAT})},
+ {"ancestors", absl::Span({})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph},
+ {"shapes", absl::Span({})},
+ {"_outside_compilation_subgraph", "O1"}}},
},
- {{"f_0_retval_retval", "F:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -2183,15 +2366,26 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
- Node* e = Unary(a, b2.opts()
- .WithName("E")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
+ Node* key_constant =
+ KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
+ Node* recv =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(recv, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* send =
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* s = Sequencer(
+ b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
+ "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
- node_builder1.Input(a).Input(b);
+ node_builder1.Input(a).Input(b).ControlInput(s);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("G"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -2236,20 +2430,22 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape.opts());
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, shape.opts());
- Node* a = InputShaped(shape.opts().WithName("A"));
- Node* c = Unary(a, shape.opts().WithName("C"));
- Node* e = BinaryUnknownShape(c, recv,
+ Node* recv = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1),
shape.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
}
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {},
@@ -2262,13 +2458,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
- {"c:o:0"},
- {{"Tinputs", absl::Span({DT_FLOAT})},
+ {"c_0_arg", "c:o:0"},
+ {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})},
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}},
{"c"}},
@@ -2285,16 +2480,18 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, b2.opts());
- Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0),
+ Node* recv = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1),
b2.opts()
.WithName("E")
- .WithControlInputs({recv, b})
+ .WithControlInputs({recv})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
- b2.opts().WithControlInput(e));
+ b2.opts().WithControlInput(e).WithAttr(
+ kXlaHasHostTransferAttrName, true));
Node* s = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
@@ -2303,9 +2500,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(b).Input(c);
Node* call =
- b2.opts().WithControlInputs({s, c}).FinalizeBuilder(&node_builder);
+ b2.opts().WithControlInputs({s, b, c}).FinalizeBuilder(&node_builder);
- Binary(a, call, b2.opts().WithName("G").WithControlInputs({e}));
+ Binary(a, call, b2.opts().WithName("G").WithControlInputs({call}));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc
index bcc3213285bee2a2094bd6c39b37ba95874d90ed..2264806d6bdabd9f26d9f83b681524399f996317 100644
--- a/tensorflow/compiler/jit/encapsulate_util.cc
+++ b/tensorflow/compiler/jit/encapsulate_util.cc
@@ -62,516 +62,6 @@ void ReplaceAttr(Node* n, const string& attr_name, const T& value) {
n->AddAttr(attr_name, value);
}
-// Step 1a ~ 1d for PreprocessForEncapsulation(). See comments of
-// PreprocessForEncapsulation() for details.
-Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name) {
- // Gather edges to remove. We should not remove the edge while iterating.
- std::vector edges_to_remove;
- for (const Edge* e : g->edges()) {
- if (!e->IsControlEdge()) {
- continue;
- }
-
- auto src_xla_computation =
- GetStringAttr(*e->src(), xla_computation_attr_name);
- auto dst_xla_computation =
- GetStringAttr(*e->dst(), xla_computation_attr_name);
- auto src_outside_compilation =
- GetStringAttr(*e->src(), outside_compilation_attr_name);
- auto dst_outside_compilation =
- GetStringAttr(*e->dst(), outside_compilation_attr_name);
-
- if (!src_xla_computation && !dst_xla_computation) {
- continue;
- } else if (src_xla_computation && !dst_xla_computation) {
- if (src_outside_compilation) {
- // Case 1c: outside compilation to host computation control edge.
- edges_to_remove.push_back(e);
-
- TF_RETURN_IF_ERROR(AppendToListAttr(
- e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
- }
- } else if (!src_xla_computation && dst_xla_computation) {
- if (dst_outside_compilation) {
- // Case 1c: host computation control to outside compilation edge.
- edges_to_remove.push_back(e);
-
- TF_RETURN_IF_ERROR(AppendToListAttr(
- e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
- }
- } else { // src_xla_computation && dst_xla_computation
- if (*src_xla_computation != *dst_xla_computation) {
- if (src_outside_compilation && dst_outside_compilation) {
- // Case 1b: outside compilation to outside compilation control edge.
- edges_to_remove.push_back(e);
-
- TF_RETURN_IF_ERROR(AppendToListAttr(
- e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
- } else if (src_outside_compilation && !dst_outside_compilation) {
- // Case 1a: outside compilation to another XLA computaition control
- // edge.
- TF_RETURN_IF_ERROR(AppendToListAttr(
- e->src(), kXlaConnectedToOtherXlaComputationAttrName,
- *dst_xla_computation));
- } else if (!src_outside_compilation && dst_outside_compilation) {
- // Case 1a: another XLA computaition to outside compilation control
- // edge.
- TF_RETURN_IF_ERROR(AppendToListAttr(
- e->dst(), kXlaConnectedFromOtherXlaComputationAttrName,
- *src_xla_computation));
- }
- }
- }
- }
-
- for (auto e : edges_to_remove) {
- g->RemoveEdge(e);
- }
- return Status::OK();
-}
-
-// Step 2 for PreprocessForEncapsulation(). See comments of
-// PreprocessForEncapsulation() for details.
-Status ProcessXlaToXlaDataEdges(Graph* g,
- const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name) {
- // Gather edges between XLA computations. Notice that we do not store `Edge*`
- // directly because we remove some nodes while adding Identity nodes, and
- // those Edge pointers might be invalidated.
- struct EdgeInfo {
- int dst_input, dst_node_id;
- };
- std::vector edges;
- for (const Edge* e : g->edges()) {
- if (e->IsControlEdge()) {
- continue;
- }
-
- auto src_xla_computation =
- GetStringAttr(*e->src(), xla_computation_attr_name);
- auto dst_xla_computation =
- GetStringAttr(*e->dst(), xla_computation_attr_name);
- auto src_outside_compilation =
- GetStringAttr(*e->src(), outside_compilation_attr_name);
- auto dst_outside_compilation =
- GetStringAttr(*e->dst(), outside_compilation_attr_name);
- if (!src_xla_computation || !dst_xla_computation) {
- continue;
- }
-
- if (*src_xla_computation != *dst_xla_computation) {
- if (src_outside_compilation || dst_outside_compilation) {
- edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
- VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
- }
- }
- }
-
- // For each XLA -> XLA edge, add an Identity node between src and dst.
- for (int i = 0; i < edges.size(); i++) {
- Node* dst = g->FindNodeId(edges[i].dst_node_id);
- const Edge* e;
- TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
- Node* src = e->src();
- int src_output = e->src_output(), dst_input = e->dst_input();
- g->RemoveEdge(e);
-
- // Create Identity node, and connect it between `src` and `dst`.
- string identity_node_name =
- absl::StrCat("bridge_", src->name(), "_", dst->name());
- DataType dtype = src->output_type(src_output);
- TF_ASSIGN_OR_RETURN(Node * identity_node,
- BuildIdentityNode(g, identity_node_name, dtype, src,
- /*requested_device=*/absl::nullopt));
- identity_node->AddAttr(kBridgeSourceNodeAttrName, src->name());
- g->AddEdge(src, src_output, identity_node, 0);
- g->AddEdge(identity_node, 0, dst, dst_input);
-
- // Replace `e->dst()` because its input node changed.
- NodeDef new_def = dst->def();
- *new_def.mutable_input(dst_input) = identity_node->name();
- TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
-
- // Other edge in `edges` might have `e->dst()` as src or dst
- // node. Before removing `e->dst()`, replace those edges with corresponding
- // edges for `dst_replace_node`.
- for (int j = i + 1; j < edges.size(); j++) {
- if (edges[j].dst_node_id == edges[i].dst_node_id) {
- edges[j].dst_node_id = dst_replace_node->id();
- }
- }
- }
- return Status::OK();
-}
-
-// Step 3 for PreprocessForEncapsulation(). See comments of
-// PreprocessForEncapsulation() for details.
-Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
- Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name) {
- // Gather edges between outside compilation and host computation. Notice that
- // we do not store `Edge*` directly because we remove some nodes while adding
- // Identity nodes, and those Edge pointers might be invalidated.
- struct EdgeInfo {
- int dst_input, dst_node_id;
- bool is_host_to_outside_compilation;
- };
- std::vector edges;
- for (const Edge* e : g->edges()) {
- if (e->IsControlEdge()) {
- continue;
- }
-
- if (e->src()->attrs().Find(xla_computation_attr_name) == nullptr &&
- e->dst()->attrs().Find(xla_computation_attr_name) != nullptr &&
- e->dst()->attrs().Find(outside_compilation_attr_name) != nullptr) {
- edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(),
- /*is_host_to_outside_compilation=*/true});
- VLOG(4) << "Host -> oc edge: " << e->DebugString();
- } else if (e->dst()->attrs().Find(xla_computation_attr_name) == nullptr &&
- e->src()->attrs().Find(xla_computation_attr_name) != nullptr &&
- e->src()->attrs().Find(outside_compilation_attr_name) !=
- nullptr) {
- edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(),
- /*is_host_to_outside_compilation=*/false});
- VLOG(4) << "Oc -> host edge: " << e->DebugString();
- }
- }
-
- // Remove the edge from host to outside compilation. Add a placeholder as
- // outside compilation node input.
- std::map placeholders;
- for (int i = 0; i < edges.size(); i++) {
- Node* dst = g->FindNodeId(edges[i].dst_node_id);
- const Edge* e;
- TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
- Node* src = e->src();
- int src_output = e->src_output(), dst_input = e->dst_input();
- g->RemoveEdge(e);
-
- // Find or create placeholder node.
- string new_name =
- edges[i].is_host_to_outside_compilation
- ? absl::StrCat(src->name(), "_host_to_oc_placeholder")
- : absl::StrCat(src->name(), "_oc_to_host_placeholder");
- auto iter = placeholders.find(new_name);
- Node* placeholder_node;
- if (iter == placeholders.end()) {
- NodeDefBuilder placeholder_builder(new_name, "Placeholder");
- placeholder_builder.Attr("dtype", src->output_type(src_output));
- if (edges[i].is_host_to_outside_compilation) {
- placeholder_builder.Attr(kHostToOutsideCompilationOriginalNodeAttrName,
- src->name());
- placeholder_builder.Attr(kHostToOutsideCompilationSrcOutputAttrName,
- src_output);
- // If this placeholder node is in outside compilation, we need to set
- // `xla_computation_attr_name` and `outside_compilation_attr_name`.
- string xla_computation_attr, outside_compilation_attr;
- TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), xla_computation_attr_name,
- &xla_computation_attr));
- TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
- outside_compilation_attr_name,
- &outside_compilation_attr));
- placeholder_builder.Attr(xla_computation_attr_name,
- xla_computation_attr);
- placeholder_builder.Attr(outside_compilation_attr_name,
- outside_compilation_attr);
- } else {
- placeholder_builder.Attr(kOutsideCompilationToHostOriginalNodeAttrName,
- src->name());
- placeholder_builder.Attr(kOutsideCompilationToHostSrcOutputAttrName,
- src_output);
- }
- NodeDef placeholder_def;
- TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
- Status s;
- placeholder_node = g->AddNode(placeholder_def, &s);
- TF_RETURN_IF_ERROR(s);
- placeholders[new_name] = placeholder_node;
- } else {
- placeholder_node = iter->second;
- }
- g->AddEdge(placeholder_node, 0, dst, dst_input);
-
- // Replace `e->dst()` because its input node changed.
- NodeDef new_def = dst->def();
- *new_def.mutable_input(dst_input) = placeholder_node->name();
- TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
-
- // Other edge in `edges` might have `e->dst()` as src or dst
- // node. Before removing `e->dst()`, replace those edges with corresponding
- // edges for `dst_replace_node`.
- for (int j = i + 1; j < edges.size(); j++) {
- if (edges[j].dst_node_id == edges[i].dst_node_id) {
- edges[j].dst_node_id = dst_replace_node->id();
- }
- }
- }
- return Status::OK();
-}
-
-// Step 1 for `PostprocessForEncapsulation`. See comments of
-// `PostprocessForEncapsulation` for details.
-Status RemovePlaceholderBetweenOutsideCompilationAndHostComputation(Graph* g) {
- // Gather all outside compilation to host computation nodes.
- struct PlaceHolderNodeInfo {
- Node* n;
- bool is_host_to_oc;
- };
- std::vector placeholder_nodes;
- for (Node* n : g->nodes()) {
- if (n->type_string() == "Placeholder") {
- if (HasNodeAttr(n->def(),
- kOutsideCompilationToHostOriginalNodeAttrName)) {
- placeholder_nodes.push_back({n, false});
- } else if (HasNodeAttr(n->def(),
- kHostToOutsideCompilationOriginalNodeAttrName)) {
- placeholder_nodes.push_back({n, true});
- }
- }
- }
-
- // Remove the placeholder nodes, and reconnect original edge.
- auto node_name_index = g->BuildNodeNameIndex();
- for (auto placeholder_iter : placeholder_nodes) {
- Node* n = placeholder_iter.n;
-
- string node_name;
- int node_src_output;
- if (placeholder_iter.is_host_to_oc) {
- TF_RETURN_IF_ERROR(
- GetNodeAttr(n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName,
- &node_name));
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(),
- kHostToOutsideCompilationSrcOutputAttrName,
- &node_src_output));
- } else {
- TF_RETURN_IF_ERROR(
- GetNodeAttr(n->attrs(), kOutsideCompilationToHostOriginalNodeAttrName,
- &node_name));
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(),
- kOutsideCompilationToHostSrcOutputAttrName,
- &node_src_output));
- }
- auto iter = node_name_index.find(node_name);
- if (iter == node_name_index.end()) {
- return errors::Internal(
- "Cannot find original node for oc -> host placeholder node ",
- node_name);
- }
-
- // Change all usage node to use the original node instead.
- Node* original_node = iter->second;
- std::vector control_edges;
- std::vector data_edges;
- for (auto e : n->out_edges()) {
- if (e->IsControlEdge()) {
- control_edges.push_back(e);
- } else {
- data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
- }
- }
- for (const Edge* e : control_edges) {
- g->AddControlEdge(original_node, e->dst());
- g->RemoveEdge(e);
- }
- for (int i = 0; i < data_edges.size(); i++) {
- Node* dst = data_edges[i].dst;
- NodeDef new_def = dst->def();
- int dst_input = data_edges[i].dst_input;
- *new_def.mutable_input(dst_input) =
- absl::StrCat(original_node->name(), ":", node_src_output);
- TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
-
- const Edge* edge_to_replace = nullptr;
- TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
- g->RemoveEdge(edge_to_replace);
- g->AddEdge(original_node, node_src_output, replace_node, dst_input);
-
- // Other edges might have `dst` as dst node. Update those edges with
- // `replace_node`.
- for (int j = i + 1; j < data_edges.size(); j++) {
- if (data_edges[j].dst == dst) {
- data_edges[j].dst = replace_node;
- }
- }
-
- // Other placeholder node might have `dst` as original node. Update
- // `node_name_index` with `replace_node`.
- node_name_index[replace_node->name()] = replace_node;
- }
-
- // Remove placeholder node.
- g->RemoveNode(n);
- }
- return Status::OK();
-}
-
-// Step 2 for `PostprocessForEncapsulation`. See comments of
-// `PostprocessForEncapsulation` for details.
-Status RemoveIdentityBetweenDifferentXlaComputation(Graph* g) {
- // Gather Identity nodes to remove.
- std::vector bridge_nodes;
- for (Node* n : g->nodes()) {
- if (n->type_string() == "Identity" &&
- HasNodeAttr(n->def(), kBridgeSourceNodeAttrName)) {
- bridge_nodes.push_back(n);
- }
- }
-
- // Remove the identity nodes, and reconnect the original edge.
- for (int i = 0; i < bridge_nodes.size(); i++) {
- Node* n = bridge_nodes[i];
- const Edge* src_edge = nullptr;
- TF_RETURN_IF_ERROR(n->input_edge(0, &src_edge));
-
- // Change all usage node to use the original node instead.
- std::vector control_edges;
- std::vector data_edges;
- for (auto e : n->out_edges()) {
- if (e->IsControlEdge()) {
- control_edges.push_back(e);
- } else {
- data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
- }
- }
- for (const Edge* e : control_edges) {
- g->AddControlEdge(src_edge->src(), e->dst());
- g->RemoveEdge(e);
- }
- for (int j = 0; j < data_edges.size(); j++) {
- Node* dst = data_edges[j].dst;
- NodeDef new_def = dst->def();
- int dst_input = data_edges[j].dst_input;
- *new_def.mutable_input(dst_input) =
- absl::StrCat(src_edge->src()->name(), ":", src_edge->src_output());
- TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
-
- const Edge* edge_to_replace = nullptr;
- TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
- g->RemoveEdge(edge_to_replace);
- g->AddEdge(src_edge->src(), src_edge->src_output(), replace_node,
- dst_input);
-
- // Other edges might have `dst` as dst node. Update those edges with
- // `replace_node`.
- for (int k = j + 1; k < data_edges.size(); k++) {
- if (data_edges[k].dst == dst) {
- data_edges[k].dst = replace_node;
- }
- }
-
- // The node we replaced might be in `bridge_nodes`. If so, update
- // `bridge_nodes` to use the replaced node.
- for (int k = i + 1; k < bridge_nodes.size(); k++) {
- if (bridge_nodes[k] == dst) {
- bridge_nodes[k] = replace_node;
- }
- }
- }
-
- // Remove Identity node.
- g->RemoveNode(n);
- }
- return Status::OK();
-}
-
-// Step 3 for `PostprocessForEncapsulation`. See comments of
-// `PostprocessForEncapsulation` for details.
-// We do not need to worry about removed nodes in step 1 and 2;
-// `PreprocessForEncapsulation` will not record control dependencies for those
-// remvoed nodes in the first place.
-Status AddControlDependencies(
- Graph* g, const std::unordered_map& cluster_node_names) {
- auto node_name_index = g->BuildNodeNameIndex();
-
- // Reconnect outside compilation to outside compilation control edge.
- for (Node* n : g->nodes()) {
- std::vector control_deps;
- Status s =
- GetNodeAttr(n->attrs(), kXlaControlDependenciesAttrName, &control_deps);
- if (!s.ok()) {
- if (s.code() != error::NOT_FOUND) {
- return s;
- } else {
- continue;
- }
- } else {
- n->ClearAttr(kXlaControlDependenciesAttrName);
- for (const string& control_input : control_deps) {
- auto iter = node_name_index.find(control_input);
- if (iter == node_name_index.end()) {
- return errors::Internal("Cannot find original node for ",
- control_input);
- }
- g->AddControlEdge(iter->second, n);
- }
- }
- }
-
- // Reconnect outside compilation to XLA computation control edge.
- for (Node* n : g->nodes()) {
- std::vector control_deps;
- Status s = GetNodeAttr(
- n->attrs(), kXlaConnectedToOtherXlaComputationAttrName, &control_deps);
- if (!s.ok()) {
- if (s.code() != error::NOT_FOUND) {
- return s;
- } else {
- continue;
- }
- } else {
- n->ClearAttr(kXlaConnectedToOtherXlaComputationAttrName);
- for (const string& control_input : control_deps) {
- auto iter = cluster_node_names.find(control_input);
- if (iter == cluster_node_names.end()) {
- return errors::Internal("Cannot find cluster node for ",
- control_input);
- }
- auto iter2 = node_name_index.find(iter->second);
- if (iter2 == node_name_index.end()) {
- return errors::Internal("Cannot find cluster node for ",
- iter->second);
- }
- g->AddControlEdge(n, iter2->second);
- }
- }
- }
-
- // Reconnect XLA computation to outside compilation control edge.
- for (Node* n : g->nodes()) {
- std::vector control_deps;
- Status s =
- GetNodeAttr(n->attrs(), kXlaConnectedFromOtherXlaComputationAttrName,
- &control_deps);
- if (!s.ok()) {
- if (s.code() != error::NOT_FOUND) {
- return s;
- } else {
- continue;
- }
- } else {
- n->ClearAttr(kXlaConnectedFromOtherXlaComputationAttrName);
- for (const string& control_input : control_deps) {
- auto iter = cluster_node_names.find(control_input);
- if (iter == cluster_node_names.end()) {
- return errors::Internal("Cannot find cluster node for ",
- control_input);
- }
- auto iter2 = node_name_index.find(iter->second);
- if (iter2 == node_name_index.end()) {
- return errors::Internal("Cannot find cluster node for ",
- iter->second);
- }
- g->AddControlEdge(iter2->second, n);
- }
- }
- }
-
- return Status::OK();
-}
-
// Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
// `PreprocessEdgesBetweenOutsideCompilations` for details.
Status PreprocessControlEdgesBetweenOutsideCompilations(
@@ -642,7 +132,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
// Remove the edge from host to outside compilation. Add a placeholder as
// outside compilation node input.
- std::map placeholders;
+ std::map, Node*> placeholders;
for (int i = 0; i < edges.size(); i++) {
Node* dst = g->FindNodeId(edges[i].dst_node_id);
const Edge* e;
@@ -652,8 +142,10 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
g->RemoveEdge(e);
// Find or create placeholder node.
- string new_name = absl::StrCat(src->name(), "_oc_to_oc_placeholder");
- auto iter = placeholders.find(new_name);
+ string new_name =
+ absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
+ auto placeholder_index = std::make_pair(src->name(), src_output);
+ auto iter = placeholders.find(placeholder_index);
Node* placeholder_node;
if (iter == placeholders.end()) {
NodeDefBuilder placeholder_builder(new_name, "Placeholder");
@@ -673,7 +165,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
Status s;
placeholder_node = g->AddNode(placeholder_def, &s);
TF_RETURN_IF_ERROR(s);
- placeholders[new_name] = placeholder_node;
+ placeholders[placeholder_index] = placeholder_node;
} else {
placeholder_node = iter->second;
}
@@ -808,20 +300,6 @@ Status PostprocessControlEdgesBetweenOutsideCompilations(
const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";
-const char kXlaConnectedToOtherXlaComputationAttrName[] =
- "_xla_connected_to_other_xla_computation";
-const char kXlaConnectedFromOtherXlaComputationAttrName[] =
- "_xla_connected_from_other_xla_computation";
-const char kXlaControlDependenciesAttrName[] = "_xla_control_dependencies";
-const char kBridgeSourceNodeAttrName[] = "_xla_bridge_src";
-const char kOutsideCompilationToHostOriginalNodeAttrName[] =
- "_xla_oc_to_host_node_name";
-const char kOutsideCompilationToHostSrcOutputAttrName[] =
- "_xla_oc_to_host_src_output";
-const char kHostToOutsideCompilationOriginalNodeAttrName[] =
- "_xla_host_to_oc_node_name";
-const char kHostToOutsideCompilationSrcOutputAttrName[] =
- "_xla_host_to_oc_src_output";
const char kXlaConnectedToXlaComputationAttrName[] =
"_xla_connected_to_xla_computation";
const char kXlaConnectedFromXlaComputationAttrName[] =
@@ -832,32 +310,7 @@ const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output";
const char kXlaControlDependenciesWithinXlaClusterAttrName[] =
"_xla_control_dependencies_within_xla_cluster";
-Status PerformStaticShapeInferenceBeforeEncapsulation(
- Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name) {
- // Find all outside compilation to XLA computation data edges.
- std::unordered_set outside_compilation_send_nodes;
- for (auto e : g->edges()) {
- if (e->IsControlEdge()) {
- continue;
- }
-
- auto src_computation = GetStringAttr(*e->src(), xla_computation_attr_name);
- auto dst_computation = GetStringAttr(*e->dst(), xla_computation_attr_name);
- if (!src_computation || !dst_computation ||
- *src_computation != *dst_computation) {
- continue;
- }
-
- auto src_outside_compilation =
- GetStringAttr(*e->src(), outside_compilation_attr_name);
- auto dst_outside_compilation =
- GetStringAttr(*e->dst(), outside_compilation_attr_name);
- if (src_outside_compilation && !dst_outside_compilation) {
- outside_compilation_send_nodes.insert(e->src());
- }
- }
-
+Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) {
// Perform shape inference.
std::map arg_shapes;
GraphShapeInfo shape_info;
@@ -865,55 +318,21 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(
InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info));
// Add attribute for output shapes.
- for (Node* n : outside_compilation_send_nodes) {
- auto iter = shape_info.find(n->name());
- if (iter == shape_info.end()) {
- continue;
- }
-
+ auto node_name_index = g->BuildNodeNameIndex();
+ for (auto iter : shape_info) {
std::vector output_shapes;
- std::transform(iter->second.begin(), iter->second.end(),
+ std::transform(iter.second.begin(), iter.second.end(),
std::back_inserter(output_shapes),
[](const InferredShape& inferred_shape) {
return inferred_shape.shape;
});
+ Node* n = node_name_index[iter.first];
n->AddAttr(kXlaInferredShapesAttrName, output_shapes);
}
return Status::OK();
}
-Status PreprocessForEncapsulation(Graph* g,
- const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name) {
- TF_RETURN_IF_ERROR(ProcessControlEdges(g, xla_computation_attr_name,
- outside_compilation_attr_name));
- TF_RETURN_IF_ERROR(ProcessXlaToXlaDataEdges(g, xla_computation_attr_name,
- outside_compilation_attr_name));
- TF_RETURN_IF_ERROR(ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
- g, xla_computation_attr_name, outside_compilation_attr_name));
- return Status::OK();
-}
-
-Status PostprocessForEncapsulation(
- Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name,
- const std::unordered_map& clusters) {
- // The `node` pointer in `XlaClusterInfo` might be invalidated in step 1/2,
- // but the node name won't change. Record cluster node name for
- // `AddControlDependencies`.
- std::unordered_map cluster_node_names;
- for (const auto& iter : clusters) {
- cluster_node_names[iter.first] = iter.second.node->name();
- }
-
- TF_RETURN_IF_ERROR(
- RemovePlaceholderBetweenOutsideCompilationAndHostComputation(g));
- TF_RETURN_IF_ERROR(RemoveIdentityBetweenDifferentXlaComputation(g));
- TF_RETURN_IF_ERROR(AddControlDependencies(g, cluster_node_names));
- return Status::OK();
-}
-
Status PreprocessEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name) {
// Remove edges from source node to outside compilation nodes, and edges
diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h
index e363bc5754ac395bae262dc67a780a0173efaf5e..c9f16d14168163e11bb19092f566f1de8724aca3 100644
--- a/tensorflow/compiler/jit/encapsulate_util.h
+++ b/tensorflow/compiler/jit/encapsulate_util.h
@@ -27,51 +27,13 @@ namespace tensorflow {
// a list of PartialTensorShape objects.
extern const char kXlaInferredShapesAttrName[];
-// Infer output shapes for outside compilation nodes which have output data
-// edges to XLA computation nodes. These shapes will be used later by XLA
-// compiler as output shapes of the outside compilation's XlaHostCompute op.
-// XLA computation nodes will be mark by attr `xla_computation_attr_name`;
-// outside compilation nodes will be marked by both attr
-// `xla_computation_attr_name` and `outside_compilation_attr_name`.
-//
-// Those outside compilation nodes will be marked with attribute
-// `kXlaInferredShapesAttrName`.
+// Infers output shapes for all nodes in graph `g`. The output shapes will be
+// stored in node attribute `kXlaInferredShapesAttrName`.
//
// We have to perform shape inference before encapsulation because after
// encapsulation, some nodes will be encapsulated into function call, and shape
// inference does not handle function call at the moment.
-Status PerformStaticShapeInferenceBeforeEncapsulation(
- Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name);
-
-// Attribute indicating that some ops in other XLA computation has control
-// dependency on this node. Attribute value will be a list of string (XLA
-// computation names).
-extern const char kXlaConnectedToOtherXlaComputationAttrName[];
-
-// Attribute indicating that this node has control dependency on some ops in
-// other XLA computation. Attribute value will be a list of string (XLA
-// computation names).
-extern const char kXlaConnectedFromOtherXlaComputationAttrName[];
-
-// Attribute indicating that this node has control dependencies on some other
-// nodes. Attribute value will be a list of string (node names).
-extern const char kXlaControlDependenciesAttrName[];
-
-// Attribute indicating that this is an Identity node added to act as a bridge
-// between different XLA computations. Attribute value will be string (source
-// node name).
-extern const char kBridgeSourceNodeAttrName[];
-
-// Attribute indicating that this is an Placeholder node added to act as a
-// temporary input node for an outside compilation node. Attribute value will be
-// string (original input node name).
-extern const char kOutsideCompilationToHostOriginalNodeAttrName[];
-
-// Attribute indicating that this is an Placeholder node added to act as a
-// temporary input node for an outside compilation node. Attribute value will be
-// int (src_output for original edge).
-extern const char kOutsideCompilationToHostSrcOutputAttrName[];
+Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g);
// Attribute indicating that some ops in this node's XLA computation has control
// dependency on this node. Attribute value will always be "true".
@@ -81,16 +43,6 @@ extern const char kXlaConnectedToXlaComputationAttrName[];
// this node's XLA computation. Attribute value will always be "true".
extern const char kXlaConnectedFromXlaComputationAttrName[];
-// Attribute indicating that this is an Placeholder node added to act as a
-// temporary input node for an host node. Attribute value will be string
-// (original input node name).
-extern const char kHostToOutsideCompilationOriginalNodeAttrName[];
-
-// Attribute indicating that this is an Placeholder node added to act as a
-// temporary input node for a host node. Attribute value will be int (src_output
-// for original edge).
-extern const char kHostToOutsideCompilationSrcOutputAttrName[];
-
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will be
// string (original input node name).
@@ -106,27 +58,6 @@ extern const char kOutsideCompilationSrcOutputAttrName[];
// (node names).
extern const char kXlaControlDependenciesWithinXlaClusterAttrName[];
-// Preprocesses edges between different XLA clusters for encapsulation. It will
-// perform the following operations in order:
-//
-// 1a. For control edges between outside compilation and another XLA
-// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName
-// = XLA computation node name" to the outside compilation node.
-// 1b. For control edges between different outside compilations (in different
-// XLA computations), remove the edge and add attr
-// "kXlaControlDependenciesAttrName = src node name" to dst node.
-// 1c. For control edges between outside compilation and host computation,
-// remove the edge and add attr "kXlaControlDependenciesAttrName = src node
-// name" to dst node.
-// 2. For data edges between different XLA computations, if either src or dst
-// is outside compilation, add an Identity node in between the edge. The
-// identity node will have attr kBridgeSourceNodeAttrName.
-// 3. For data edges between outside compilation and host computation, remove
-// the edge and create a Placeholder node as dst node's input.
-Status PreprocessForEncapsulation(Graph* g,
- const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name);
-
// Information for XLA computation.
struct XlaClusterInfo {
// Add an explicitly-defined default constructor for this class.
@@ -158,24 +89,6 @@ struct XlaClusterInfo {
const std::map host_compute_core;
};
-// Postprocesses edges between different XLA clusters for encapsulation. This
-// function reverts what `PreprocessForEncapsulation` did. It will perform the
-// following operations in order:
-//
-// 1. Remove Placeholder nodes between outside compilation and host computation
-// (created in `PreprocessForEncapsulation` step 3).
-// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2.
-// 3a. Reconnect control edges between outside compilation and another XLA
-// computation (marked by `PreprocessForEncapsulation` step 1a).
-// 3b. Reconnect control edges between different outside compilations (marked by
-// `PreprocessForEncapsulation` step 1b).
-// 3c. Reconnect control edges between outside compilation and host computation
-// (marked by `PreprocessForEncapsulation` step 1c).
-Status PostprocessForEncapsulation(
- Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name,
- const std::unordered_map& clusters);
-
// Preprocesses edges within the same XLA cluster. It will perform the following
// operations in order:
//
diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc
index 25c32cef01d7f9877a35001457539f2ad189192f..3bb979e0698d2d6be42ed5bae66c25267928192c 100644
--- a/tensorflow/compiler/jit/encapsulate_util_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_util_test.cc
@@ -38,24 +38,11 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
Graph g(OpRegistry::Global());
TF_CHECK_OK(s.ToGraph(&g));
- // "add" node is outside compilation node, "identity" node is XLA node.
- auto node_index = g.BuildNodeNameIndex();
- Node *add_node = node_index["add"], *identity_node = node_index["identity"];
- add_node->AddAttr("_xla", "cluster");
- add_node->AddAttr("_oc", "cluster");
- identity_node->AddAttr("_xla", "cluster");
- TF_CHECK_OK(
- PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc"));
+ TF_CHECK_OK(PerformStaticShapeInferenceBeforeEncapsulation(&g));
- // Check that only "add" node now has _xla_inferred_shapes attr.
- std::vector nodes_with_inferred_shape;
- for (Node *n : g.nodes()) {
- if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) {
- nodes_with_inferred_shape.push_back(n);
- }
- }
- EXPECT_EQ(nodes_with_inferred_shape.size(), 1);
- EXPECT_EQ(nodes_with_inferred_shape[0], add_node);
+ // Check that "add" node now has _xla_inferred_shapes attr.
+ auto node_index = g.BuildNodeNameIndex();
+ Node *add_node = node_index["add"];
std::vector output_shapes;
TF_CHECK_OK(GetNodeAttr(add_node->attrs(), kXlaInferredShapesAttrName,
&output_shapes));
@@ -66,293 +53,4 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
EXPECT_EQ(shape_proto.dim(0).size(), 2);
}
-TEST(PreprocessForEncapsulationTest, ControlEdges) {
- // Build the graph:
- // "const_0" and "const_1" in host computation
- // "add" = "const_0" + "const_1" in XLA computation 0
- // "identity0" = "add" in XLA computation 0 & outside compilation 0
- // "identity1" = "identity0" in XLA computation 0
- // "identity2" = "identity1" in host computation
- // "identity3" = "identity2" in XLA computation 1
- // "identity4" = "identity3" in XLA computation 1 & outside compilation 1
- // "identity5" = "identity4" in XLA computation 1
- // "identity6" = "identity5" in host computation
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
- Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
- Output add = ops::Add(s.WithOpName("add"), const_0, const_1);
- Output identity0 = ops::Identity(s.WithOpName("identity0"), add);
- Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
- Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
- Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
- Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3);
- Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4);
- Graph g(OpRegistry::Global());
- TF_CHECK_OK(s.ToGraph(&g));
- auto node_index = g.BuildNodeNameIndex();
-
- // Set XLA computation/outside compilation attr, and add control edges.
- Node *const0_node = node_index["const_0"], *add_node = node_index["add"],
- *identity0_node = node_index["identity0"],
- *identity1_node = node_index["identity1"],
- *identity2_node = node_index["identity2"],
- *identity3_node = node_index["identity3"],
- *identity4_node = node_index["identity4"],
- *identity5_node = node_index["identity5"];
- add_node->AddAttr("_xla", "0");
- identity0_node->AddAttr("_xla", "0");
- identity0_node->AddAttr("_oc", "0");
- identity1_node->AddAttr("_xla", "0");
- identity3_node->AddAttr("_xla", "1");
- identity4_node->AddAttr("_xla", "1");
- identity4_node->AddAttr("_oc", "0");
- identity5_node->AddAttr("_xla", "1");
- // Case 1a: control edges between outside compilation and another XLA
- // computation.
- g.AddControlEdge(identity0_node, identity3_node);
- g.AddControlEdge(identity1_node, identity4_node);
- // Case 1b: control edges between different outside compilations.
- g.AddControlEdge(identity0_node, identity4_node);
- // Case 1c: control edges between outside compilation and host computation.
- g.AddControlEdge(const0_node, identity0_node);
- g.AddControlEdge(identity0_node, identity2_node);
-
- TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
-
- // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name"
- // to the outside compilation node.
- std::vector attr;
- TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
- kXlaConnectedToOtherXlaComputationAttrName, &attr));
- EXPECT_EQ(attr.size(), 1);
- EXPECT_EQ(attr[0], "1");
- attr.clear();
- TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
- kXlaConnectedFromOtherXlaComputationAttrName, &attr));
- EXPECT_EQ(attr.size(), 1);
- EXPECT_EQ(attr[0], "0");
- // Case 1b: add attr "_xla_control_deps = src node name" to dst node.
- attr.clear();
- TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
- kXlaControlDependenciesAttrName, &attr));
- EXPECT_EQ(attr.size(), 1);
- EXPECT_EQ(attr[0], "identity0");
- // Case 1c: add attr "_xla_control_deps = src node name" to dst node.
- attr.clear();
- TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
- kXlaControlDependenciesAttrName, &attr));
- EXPECT_EQ(attr.size(), 1);
- EXPECT_EQ(attr[0], "const_0");
- attr.clear();
- TF_CHECK_OK(GetNodeAttr(identity2_node->def(),
- kXlaControlDependenciesAttrName, &attr));
- EXPECT_EQ(attr.size(), 1);
- EXPECT_EQ(attr[0], "identity0");
-}
-
-TEST(PreprocessForEncapsulationTest, DataEdges) {
- // Build the graph:
- // "const_0" and "const_1" in host computation
- // "add0" = "const_0" + "const_1" in XLA computation 0
- // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
- // "identity0" = "add1" in XLA computation 0
- // "add2" = "add1" + "identity0" in host computation
- // "add3" = "add1" + "add2" in XLA computation 1
- // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1
- // "identity1" = "add4" in XLA computation 1
- // "identity2" = "identity1" in host computation
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
- Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
- Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
- Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
- Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
- Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
- Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
- Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
- Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
- Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
- Graph g(OpRegistry::Global());
- TF_CHECK_OK(s.ToGraph(&g));
- auto node_index = g.BuildNodeNameIndex();
-
- // Set XLA computation/outside compilation attr.
- Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
- *identity0_node = node_index["identity0"],
- *add3_node = node_index["add3"], *add4_node = node_index["add4"],
- *identity1_node = node_index["identity1"];
- add0_node->AddAttr("_xla", "0");
- add1_node->AddAttr("_xla", "0");
- add1_node->AddAttr("_oc", "0");
- identity0_node->AddAttr("_xla", "0");
- add3_node->AddAttr("_xla", "1");
- add4_node->AddAttr("_xla", "1");
- add4_node->AddAttr("_oc", "0");
- identity1_node->AddAttr("_xla", "1");
-
- TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
-
- // Check input nodes for related data edges.
- node_index = g.BuildNodeNameIndex();
- // Step 2: add an Identity node between different XLA computations.
- Node *bridge_add1_add3 = node_index["bridge_add1_add3"];
- EXPECT_NE(bridge_add1_add3, nullptr);
- string str;
- TF_CHECK_OK(
- GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str));
- EXPECT_EQ(str, "add1");
- Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"];
- EXPECT_NE(bridge_identity0_add4, nullptr);
- // Step 3: add placeholder for edges between host computation and outside
- // compilation.
- EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder");
- Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"];
- TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
- kOutsideCompilationToHostOriginalNodeAttrName, &str));
- EXPECT_EQ(str, "add1");
- int i;
- TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
- kOutsideCompilationToHostSrcOutputAttrName, &i));
- EXPECT_EQ(i, 0);
- add4_node = node_index["add4"];
- ASSERT_NE(add4_node, nullptr);
- EXPECT_EQ(add4_node->def().input(0),
- "bridge_identity0_add4_host_to_oc_placeholder");
- Node *identity0_host_to_oc_placeholder =
- node_index["bridge_identity0_add4_host_to_oc_placeholder"];
- TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
- kHostToOutsideCompilationOriginalNodeAttrName, &str));
- EXPECT_EQ(str, "bridge_identity0_add4");
- TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
- kHostToOutsideCompilationSrcOutputAttrName, &i));
- EXPECT_EQ(i, 0);
-}
-
-TEST(PostprocessForEncapsulationTest, ControlEdges) {
- // Build the graph:
- // "const0"
- // "identity0" = "const0" (XLA computation 0)
- // "identity1" = "identity0"
- // "identity2" = "identity1" (XLA computation 1)
- // "identity3" = "identity2"
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output const0 = ops::Const(s.WithOpName("const0"), 1, {});
- Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
- Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
- Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
- Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
- Graph g(OpRegistry::Global());
- TF_CHECK_OK(s.ToGraph(&g));
- auto node_index = g.BuildNodeNameIndex();
-
- // Set XLA computation/outside compilation attr, and add control edges.
- Node *const0_node = node_index["const0"],
- *identity0_node = node_index["identity0"],
- *identity1_node = node_index["identity1"],
- *identity2_node = node_index["identity2"],
- *identity3_node = node_index["identity3"];
- identity1_node->AddAttr(kXlaConnectedFromOtherXlaComputationAttrName,
- std::vector{"0"});
- identity1_node->AddAttr(kXlaConnectedToOtherXlaComputationAttrName,
- std::vector{"1"});
- identity3_node->AddAttr(kXlaControlDependenciesAttrName,
- std::vector{"const0", "identity1"});
-
- std::unordered_map clusters;
- clusters["0"].node = identity0_node;
- clusters["1"].node = identity2_node;
- TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters));
-
- // Case 3a: we have control edge identity0 -> identity1, and identity1 ->
- // identity2.
- bool edge_identity0_identity1 = false, edge_identity1_identity2 = false;
- for (const Edge *e : g.edges()) {
- if (!e->IsControlEdge()) {
- continue;
- }
- if (e->src() == identity0_node && e->dst() == identity1_node) {
- edge_identity0_identity1 = true;
- } else if (e->src() == identity1_node && e->dst() == identity2_node) {
- edge_identity1_identity2 = true;
- }
- }
- EXPECT_TRUE(edge_identity0_identity1);
- EXPECT_TRUE(edge_identity1_identity2);
- // Case 3b: we have control edge const0 -> identity3, and identity1 ->
- // identity3.
- bool edge_const0_identity3 = false, edge_identity1_identity3 = false;
- for (const Edge *e : g.edges()) {
- if (!e->IsControlEdge()) {
- continue;
- }
- if (e->src() == const0_node && e->dst() == identity3_node) {
- edge_const0_identity3 = true;
- } else if (e->src() == identity1_node && e->dst() == identity3_node) {
- edge_identity1_identity3 = true;
- }
- }
- EXPECT_TRUE(edge_const0_identity3);
- EXPECT_TRUE(edge_identity1_identity3);
-}
-
-TEST(PostprocessForEncapsulationTest, DataEdges) {
- // Build the graph:
- // "const0" in outside compilation "0"
- // "placeholder0" (for "const0") in host computation
- // "add0" = "placeholder0" + "placeholder0" in host computation
- // "placeholder1" (for "add0") in outside compilation 1
- // "add1" = "placeholder1" + "placeholder1" in outside compilation 1
- //
- // "bridge" = "placeholder0" in host computation
- // "placeholder2" (for "bridge") in outside compilation 1
- // "add2" = "placeholder2" + "placeholder2" in outside compilation 1
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output const0 = ops::Const(s.WithOpName("const0"), 1, {});
- Output placeholder0 =
- ops::Placeholder(s.WithOpName("placeholder0"), DT_INT32);
- Output add0 = ops::Add(s.WithOpName("add0"), placeholder0, placeholder0);
- Output placeholder1 =
- ops::Placeholder(s.WithOpName("placeholder1"), DT_INT32);
- Output add1 = ops::Add(s.WithOpName("add1"), placeholder1, placeholder1);
- Output bridge = ops::Identity(s.WithOpName("bridge"), placeholder0);
- Output placeholder2 =
- ops::Placeholder(s.WithOpName("placeholder2"), DT_INT32);
- Output add2 = ops::Add(s.WithOpName("add2"), placeholder2, placeholder2);
- Graph g(OpRegistry::Global());
- TF_CHECK_OK(s.ToGraph(&g));
- auto node_index = g.BuildNodeNameIndex();
-
- // Set related attributes.
- Node *placeholder0_node = node_index["placeholder0"];
- placeholder0_node->AddAttr(kOutsideCompilationToHostOriginalNodeAttrName,
- "const0");
- placeholder0_node->AddAttr(kOutsideCompilationToHostSrcOutputAttrName, 0);
- Node *placeholder1_node = node_index["placeholder1"];
- placeholder1_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName,
- "add0");
- placeholder1_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0);
- Node *bridge_node = node_index["bridge"];
- bridge_node->AddAttr(kBridgeSourceNodeAttrName, "const0");
- Node *placeholder2_node = node_index["placeholder2"];
- placeholder2_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName,
- "bridge");
- placeholder2_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0);
-
- std::unordered_map clusters;
- TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters));
-
- // Result graph should be:
- // "add0" = "const0" + "const0"
- // "add1" = "add0" + "add0"
- // "add2" = "const0" + "const0"
- node_index = g.BuildNodeNameIndex();
- EXPECT_EQ(node_index.size(), 6);
- EXPECT_EQ(node_index["add0"]->def().input(0), "const0:0");
- EXPECT_EQ(node_index["add0"]->def().input(1), "const0:0");
- EXPECT_EQ(node_index["add1"]->def().input(0), "add0:0");
- EXPECT_EQ(node_index["add1"]->def().input(1), "add0:0");
- EXPECT_EQ(node_index["add2"]->def().input(0), "const0:0");
- EXPECT_EQ(node_index["add2"]->def().input(1), "const0:0");
-}
-
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index d334100aa4a915a87fb05d371e0e3379a7ee05f2..ec745cdbb7e237f8b4935dd41e9791fc75f5355d 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -297,6 +297,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors,
NodeDef def;
def.set_name(launch->name());
+ MergeDebugInfo(NodeDebugInfo(launch->def()), &def);
// Target the XLA CPU/GPU backends.
VLOG(2) << "Replacing with XlaLaunch";
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
index e3c7e2f89be9b37b51a633dabb099969c181013f..8b01768c49422b331b52a8ba31bade000c95722e 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
@@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -98,9 +100,12 @@ xla::StatusOr BuildRecvAtHostNode(
recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes);
// The correct device_ordinal will be inserted during replication in a
// subsequent rewrite.
- recv_at_host_builder.Attr("device_ordinal", 0);
+ AttrValue device_ordinal_value;
+ device_ordinal_value.set_placeholder("device_ordinal");
+ recv_at_host_builder.Attr("device_ordinal", device_ordinal_value);
recv_at_host_builder.Attr(
"key", absl::StrCat("host_compute_channel_", oc_cluster_name));
+ recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true);
recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def));
Status s;
@@ -197,9 +202,12 @@ xla::StatusOr BuildSendFromHostNode(
send_from_host_builder.Attr("Tinputs", send_from_host_dtypes);
// The correct device_ordinal will be inserted during replication in a
// subsequent rewrite.
- send_from_host_builder.Attr("device_ordinal", 0);
+ AttrValue device_ordinal_value;
+ device_ordinal_value.set_placeholder("device_ordinal");
+ send_from_host_builder.Attr("device_ordinal", device_ordinal_value);
send_from_host_builder.Attr(
"key", absl::StrCat("host_compute_channel_", oc_cluster_name));
+ send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true);
std::vector inputs(send_from_host_dtypes.size());
for (auto* n : ret_nodes) {
int index;
@@ -322,6 +330,38 @@ xla::StatusOr BuildXlaHostComputeNodeDef(
return new_def;
}
+Status ValidateOutsideCompilationCallNode(Node* call_node) {
+ // DT_INT64 as input/output for outside compilation is not supported yet:
+ // b/120809951.
+ for (const Edge* e : call_node->in_edges()) {
+ if (e->IsControlEdge()) {
+ continue;
+ }
+ DataType dtype = e->src()->output_type(e->src_output());
+ if (dtype == DT_INT64) {
+ return errors::Unimplemented(
+ "int64 input for outside compilation is not supported yet: "
+ "b/120809951. Please cast output of node ",
+ e->src()->DebugString(),
+ " to int32 before feeding it into outside compilation.");
+ }
+ }
+ for (const Edge* e : call_node->out_edges()) {
+ if (e->IsControlEdge()) {
+ continue;
+ }
+ DataType dtype = e->dst()->input_type(e->dst_input());
+ if (dtype == DT_INT64) {
+ return errors::Unimplemented(
+ "int64 output for outside compilation is not supported yet: "
+ "b/120809951. Please cast input of node ",
+ e->dst()->DebugString(),
+ " to int32 before returning it from outside compilation.");
+ }
+ }
+ return Status::OK();
+}
+
// Replace outside compilation function call node with XlaHostCompute node.
// If the function call node has no input/output edges, we will just remove it
// and not create a XlaHostCompute node.
@@ -357,6 +397,47 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
return Status::OK();
}
+// Resets "device_ordinal" attr to placeholder value for related nodes
+// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If nodes containing
+// XlaRecvAtHost/XlaSendFromHost).
+Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) {
+ AttrValue device_ordinal_value;
+ device_ordinal_value.set_placeholder("device_ordinal");
+ for (Node* n : g->nodes()) {
+ if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
+ continue;
+ }
+
+ if (n->type_string() == "_XlaRecvAtHost" ||
+ n->type_string() == "_XlaSendFromHost") {
+ n->ClearAttr("device_ordinal");
+ n->AddAttr("device_ordinal", device_ordinal_value);
+ } else if (n->type_string() == "If") {
+ for (const string& attr_name :
+ std::vector{"then_branch", "else_branch"}) {
+ NameAttrList branch_func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
+ (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value;
+ n->ClearAttr(attr_name);
+ n->AddAttr(attr_name, branch_func);
+ }
+ } else if (n->type_string() == "While") {
+ for (const string& attr_name : std::vector{"cond", "body"}) {
+ NameAttrList branch_func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
+ (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value;
+ n->ClearAttr(attr_name);
+ n->AddAttr(attr_name, branch_func);
+ }
+ } else {
+ return errors::Internal("Unknown node marked with ",
+ kXlaHasHostTransferAttrName, ": ",
+ n->DebugString());
+ }
+ }
+ return Status::OK();
+}
+
// For an XLA computation, builds host side graph given all outside compilation
// graphs inside it. The host side graph contains:
// 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and
@@ -368,8 +449,8 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
Status ConstructHostGraph(
const string& xla_cluster_name, const string& outside_compilation_attr_name,
const std::vector& outside_compilation_host_graphs,
- FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) {
- host_graph->reset(new Graph(fld));
+ FunctionLibraryDefinition* fld, const string& host_graph_func_name) {
+ Graph host_graph(fld);
// Create sequencer node in host graph.
NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"),
@@ -378,24 +459,34 @@ Status ConstructHostGraph(
NodeDef sequencer_def;
TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def));
Status s;
- Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s);
+ Node* sequencer = host_graph.AddNode(sequencer_def, &s);
TF_RETURN_IF_ERROR(s);
// Create key placeholder in host graph.
TF_ASSIGN_OR_RETURN(
Node * key_placeholder,
- AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get()));
+ AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
// For each outside compilation graph, copy them to host graph with the
// following changes:
// a) Use key_placeholder in host graph instead of its own.
- // b) Add control edge from RecvAtHost/SendFromHost to sequencer.
+ // b) Add control edge from host transfer nodes (XlaRecvAtHost,
+ // XlaSendFromHost, If/While nodes containing
+ // XlaRecvAtHost/XlaSendFromHost) to sequencer node.
// c) Clear node_def.device(), so device placer won't get confused.
for (const string& host_func : outside_compilation_host_graphs) {
VLOG(4) << "Expanding host graph " << host_func;
+ // Temporarily use "0" as "device_ordinal". It will be reset to placeholder
+ // value after we expanded all host graphs. We cannot just use placeholder
+ // value here because FunctionDef instantiation does not allow placeholder
+ // value for attributes.
+ AttrValue device_ordinal_attr;
+ device_ordinal_attr.set_i(0);
+ protobuf::Map attrs;
+ attrs["device_ordinal"] = device_ordinal_attr;
FunctionBody* host_fbody = nullptr;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
- *fld->Find(host_func), AttrSlice(), fld,
+ *fld->Find(host_func), AttrSlice(&attrs), fld,
[&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig);
},
@@ -408,8 +499,8 @@ Status ConstructHostGraph(
FixupSourceAndSinkEdges(host_fbody->graph);
std::map node_map;
- node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node();
- node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node();
+ node_map[host_fbody->graph->source_node()] = host_graph.source_node();
+ node_map[host_fbody->graph->sink_node()] = host_graph.sink_node();
Status s;
ReverseDFS(
*host_fbody->graph, /*enter=*/nullptr,
@@ -431,7 +522,7 @@ Status ConstructHostGraph(
NodeDef copy_def = n->def();
// Change c).
copy_def.clear_device();
- copy = (*host_graph)->AddNode(copy_def, &s);
+ copy = host_graph.AddNode(copy_def, &s);
if (!s.ok()) {
return;
}
@@ -446,22 +537,23 @@ Status ConstructHostGraph(
e->src()->DebugString());
return;
}
- (*host_graph)
- ->AddEdge(node_map[e->src()], e->src_output(), copy,
- e->dst_input());
+ host_graph.AddEdge(node_map[e->src()], e->src_output(), copy,
+ e->dst_input());
}
// Change b).
- if (copy->type_string() == "_XlaRecvAtHost" ||
- copy->type_string() == "_XlaSendFromHost") {
- (*host_graph)->AddControlEdge(copy, sequencer);
+ if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) {
+ host_graph.AddControlEdge(copy, sequencer);
}
},
NodeComparatorID());
+
if (!s.ok()) {
return s;
}
}
+ // Reset "device_ordinal" to placeholder value.
+ TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(&host_graph));
// sequencer and key_placeholder might be dead nodes. Prune them if necessary.
// - sequencer should be pruned iff it has no input control edges from
@@ -470,21 +562,30 @@ Status ConstructHostGraph(
// - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost.
// We don't need to do anything special.
if (!sequencer->in_edges().empty()) {
- (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node());
+ host_graph.AddControlEdge(sequencer, host_graph.sink_node());
}
PruneForReverseReachability(
- host_graph->get(),
- std::unordered_set{(*host_graph)->sink_node()});
+ &host_graph, std::unordered_set{host_graph.sink_node()});
// Postprocess edges between different outside compilations.
TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
- host_graph->get(), outside_compilation_attr_name));
+ &host_graph, outside_compilation_attr_name));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("extract_outside_compilation_host_graph_for_",
xla_cluster_name),
- **host_graph, fld);
+ host_graph, fld);
+ }
+
+ FunctionDef host_graph_fdef;
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(host_graph, host_graph_func_name, &host_graph_fdef));
+ if (fld->Find(host_graph_func_name)) {
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(host_graph_func_name, host_graph_fdef));
+ } else {
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef));
}
return Status::OK();
@@ -492,8 +593,28 @@ Status ConstructHostGraph(
// Expand XLA computation's outside compilation host side graph into main graph.
// Add a control edge between sequencer node and the XLA computation node.
-Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph,
+Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
+ FunctionLibraryDefinition* fld,
+ const string& host_graph_func_name,
Node* xla_computation_node) {
+ // Temporarily use "0" as "device_ordinal". It will be rewritten with the
+ // correct value in a later pass. We cannot just use placeholder value here
+ // because FunctionDef instantiation does not allow placeholder value for
+ // attributes.
+ AttrValue device_ordinal_attr;
+ device_ordinal_attr.set_i(0);
+ protobuf::Map attrs;
+ attrs["device_ordinal"] = device_ordinal_attr;
+ FunctionBody* fbody = nullptr;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+ *fld->Find(host_graph_func_name), AttrSlice(&attrs), fld,
+ [&](const string& op, const OpDef** sig) {
+ return fld->LookUpOpDef(op, sig);
+ },
+ &fbody));
+ std::unique_ptr fbody_deleter(fbody);
+ Graph* host_graph = fbody->graph;
+
// We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
// reachable from sink node so all nodes will be copied.
// TODO(b/77601805): consolidate copy graph functions.
@@ -545,23 +666,25 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph,
return s;
}
-// Rewrites shape inference graph for outside compilation.
-// 1. If the outside compilation is a "top-level" one (not in a function of any
-// If/While/etc.), this shape inference graph might have host computation to
-// outside compilation placeholder nodes, which will cause shape inference to
-// fail. However, those nodes are not in `host_graph` any more (because we
-// have executed `PostprocessForEncapsultion`). In this case, we clear the
-// graph, and copy SendFromHost with all its predecessors from `host_graph`.
-// This case is detected by whether the SendFromHost node exists in
-// `host_graph` as well.
-// 2. Remove control edges, and prune nodes that are not useful for shape
-// inference.
+// Rewrites shape inference graph for outside compilation:
+// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from
+// `host_graph`. Because we might still have outside compilation to outside
+// compilation placeholder nodes in shape inference graph, which will prevent
+// us from inferring XlaSendFromHost shape. But in `host_graph`, we already
+// removed those placeholder nodes.
+// 2) Remove control edges.
+// 3) Prune nodes that are not useful for shape inference.
Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
Graph* host_graph,
FunctionLibraryDefinition* fld) {
+ // Use "0" as "device_ordinal". It does not matter for shape inference.
+ AttrValue device_ordinal_attr;
+ device_ordinal_attr.set_i(0);
+ protobuf::Map attrs;
+ attrs["device_ordinal"] = device_ordinal_attr;
FunctionBody* fbody = nullptr;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
- *fld->Find(shape_inference_graph_name), AttrSlice(), fld,
+ *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld,
[&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig);
},
@@ -650,6 +773,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
g->RemoveEdge(e);
}
}
+
// Nodes that are not reverse reachable from SendFromHost are not useful for
// shape inference. Prune them.
PruneForReverseReachability(g,
@@ -669,6 +793,572 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
return Status::OK();
}
+// Builds XlaSendToHost node which sends cond predicate to host.
+xla::StatusOr BuildSendIfPredNode(const string& name,
+ const string& host_transfer_key,
+ Node* pred_node, Graph* g) {
+ NodeDefBuilder send_pred_builder(name, "XlaSendToHost");
+ send_pred_builder.Attr("Tinput", DT_BOOL);
+ send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0"));
+ send_pred_builder.Attr(kXlaTokenInputNodesAttrName,
+ std::vector{kXlaTokenArgNodeName});
+ send_pred_builder.Input(pred_node->name(), 0, DT_BOOL);
+ NodeDef send_pred_def;
+ TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def));
+ Status s;
+ Node* send_pred_node = g->AddNode(send_pred_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ g->AddEdge(pred_node, 0, send_pred_node, 0);
+ return send_pred_node;
+}
+
+// Replaces key placeholder node with an _Arg node.
+Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name,
+ const string& func_name,
+ FunctionLibraryDefinition* fld) {
+ // Temporarily use "0" as "device_ordinal". It will be reset to placeholder
+ // value after rewriting.
+ AttrValue device_ordinal_attr;
+ device_ordinal_attr.set_i(0);
+ protobuf::Map attrs;
+ attrs["device_ordinal"] = device_ordinal_attr;
+ FunctionBody* fbody = nullptr;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+ *fld->Find(func_name), AttrSlice(&attrs), fld,
+ [&](const string& op, const OpDef** sig) {
+ return fld->LookUpOpDef(op, sig);
+ },
+ &fbody));
+ std::unique_ptr fbody_deleter(fbody);
+ Graph* g = fbody->graph;
+
+ // Find or create the key placeholder node.
+ Node* key_placeholder = nullptr;
+ for (Node* n : g->nodes()) {
+ if (IsKeyPlaceholderNode(*n)) {
+ key_placeholder = n;
+ break;
+ }
+ }
+ if (!key_placeholder) {
+ TF_ASSIGN_OR_RETURN(key_placeholder,
+ AddHostComputeKeyPlaceholder(xla_cluster_name, g));
+ }
+
+ // Build the _Arg node, and replace key placeholder node with it.
+ NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp);
+ arg_builder.Attr("T", DT_STRING);
+ arg_builder.Attr("index", 0);
+ NodeDef arg_def;
+ TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
+ TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status());
+
+ // Reset "device_ordinal" to placeholder value.
+ TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g));
+
+ FunctionDef replace_fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, func_name, &replace_fdef));
+ TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef));
+ return Status::OK();
+}
+
+// Builds host side graph for If node.
+Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name,
+ const string& outside_compilation_attr_name,
+ const string& xla_cluster_name,
+ const string& if_node_name,
+ const string& host_transfer_key,
+ const string& host_graph_func_name,
+ FunctionLibraryDefinition* fld,
+ const string& then_branch_host_func_name,
+ const string& else_branch_host_func_name) {
+ Graph host_graph(fld);
+ string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
+ AttrValue device_ordinal_value;
+ device_ordinal_value.set_placeholder("device_ordinal");
+
+ // Step 1: add key placeholder node.
+ TF_ASSIGN_OR_RETURN(
+ Node * key_placeholder,
+ AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
+
+ // Step 2: build XlaRecvAtHost node to recv predicate.
+ NodeDefBuilder recv_pred_builder(
+ absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost");
+ recv_pred_builder.Attr("Toutputs", std::vector{DT_BOOL});
+ recv_pred_builder.Attr("key", host_transfer_key);
+ recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
+ recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+ recv_pred_builder.Attr(outside_compilation_attr_name,
+ outside_compilation_name);
+ recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
+ recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING);
+ NodeDef recv_pred_def;
+ TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
+ Status s;
+ Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0);
+
+ // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key
+ // placeholder with an _Arg node.
+ TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+ xla_cluster_name, then_branch_host_func_name, fld));
+ TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+ xla_cluster_name, else_branch_host_func_name, fld));
+
+ // Step 4: build If node to choose between `{then, else}_branch_host_graph`.
+ NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If");
+ if_builder.Attr("Tcond", DT_BOOL);
+ if_builder.Attr("Tin", std::vector{DT_STRING});
+ if_builder.Attr("Tout", std::vector{});
+ NameAttrList host_then_branch, host_else_branch;
+ host_then_branch.set_name(then_branch_host_func_name);
+ (*host_then_branch.mutable_attr())["device_ordinal"] = device_ordinal_value;
+ host_else_branch.set_name(else_branch_host_func_name);
+ (*host_else_branch.mutable_attr())["device_ordinal"] = device_ordinal_value;
+ if_builder.Attr("then_branch", host_then_branch);
+ if_builder.Attr("else_branch", host_else_branch);
+ if_builder.Attr(kXlaHasHostTransferAttrName, true);
+ if_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+ if_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
+ if_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
+ std::vector if_inputs{
+ {key_placeholder->name(), 0, DT_STRING}};
+ if_builder.Input(if_inputs);
+ NodeDef if_def;
+ TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def));
+ Node* if_node = host_graph.AddNode(if_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ host_graph.AddEdge(recv_pred_node, 0, if_node, 0);
+ host_graph.AddEdge(key_placeholder, 0, if_node, 1);
+
+ // Convert `host_graph` to function, and add a "device_ordinal" attr.
+ FunctionDef oc_host_graph_fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
+ &oc_host_graph_fdef));
+ if (fld->Find(host_graph_func_name)) {
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
+ } else {
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
+ }
+
+ return Status::OK();
+}
+
+// Rewrites loop cond to add a node which sends loop cond to host.
+Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld,
+ const NameAttrList& loop_cond_func,
+ const string& while_node_name,
+ const string& host_transfer_key) {
+ // Instantiate the loop cond function.
+ FunctionBody* fbody = nullptr;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+ *fld->Find(loop_cond_func.name()), AttrSlice(&loop_cond_func.attr()), fld,
+ [&](const string& op, const OpDef** sig) {
+ return fld->LookUpOpDef(op, sig);
+ },
+ &fbody));
+ std::unique_ptr fbody_deleter(fbody);
+ Graph* g = fbody->graph;
+
+ // Find the _Retval node and the loop cond node.
+ Node* ret_node = nullptr;
+ for (Node* n : g->nodes()) {
+ if (n->type_string() == "_Retval") {
+ if (ret_node) {
+ return errors::Internal("Multiple return node for loop cond function ",
+ loop_cond_func.name(), ": ",
+ ret_node->DebugString(), " and ",
+ n->DebugString());
+ } else {
+ ret_node = n;
+ }
+ }
+ }
+ if (!ret_node) {
+ return errors::Internal("No _Retval node for loop cond function ",
+ loop_cond_func.name());
+ }
+ Node* loop_cond;
+ TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
+
+ // Build the XlaSendToHost node.
+ NodeDefBuilder send_loop_cond_builder(
+ absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost");
+ send_loop_cond_builder.Attr("Tinput", DT_BOOL);
+ send_loop_cond_builder.Attr("key",
+ absl::StrCat(host_transfer_key, "_dtoh_0"));
+ send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName,
+ std::vector{kXlaTokenArgNodeName});
+ send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL);
+ NodeDef send_loop_cond_def;
+ TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def));
+ Status s;
+ Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
+
+ // Replace original function.
+ FunctionDef replace_fdef;
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef));
+ TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef));
+
+ return Status::OK();
+}
+
+// Rewrites while loop cond function for host.
+Status RewriteHostWhileLoopCond(
+ const string& cond_host_func_name, const string& while_node_name,
+ const string& host_transfer_key, const string& xla_cluster_attr_name,
+ const string& xla_cluster_name, const string& outside_compilation_attr_name,
+ const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
+ // Replace key placeholder node with _Arg node.
+ TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+ xla_cluster_name, cond_host_func_name, fld));
+
+ // Instantiate cond function.
+ AttrValue device_ordinal_temp_value;
+ device_ordinal_temp_value.set_i(0);
+ protobuf::Map attrs;
+ attrs["device_ordinal"] = device_ordinal_temp_value;
+ FunctionBody* cond_fbody = nullptr;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+ *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld,
+ [&](const string& op, const OpDef** sig) {
+ return fld->LookUpOpDef(op, sig);
+ },
+ &cond_fbody));
+ std::unique_ptr cond_fbody_deleter(cond_fbody);
+ Graph* cond_graph = cond_fbody->graph;
+ Node* key_arg = nullptr;
+ for (Node* n : cond_graph->nodes()) {
+ if (n->type_string() == "_Arg") {
+ key_arg = n;
+ }
+ }
+ if (!key_arg) {
+ return errors::Internal(
+ "No _Arg node found for host compute key in function ",
+ cond_host_func_name);
+ }
+
+ // Add an XlaRecvAtHost node to use as cond function return value.
+ // We don't need to set kXlaHasHostTransferAttrName for this node, because
+ // it's already added for the "While" node on the host.
+ NodeDefBuilder recv_pred_builder(
+ absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost");
+ recv_pred_builder.Attr("Toutputs", std::vector{DT_BOOL});
+ recv_pred_builder.Attr("key", host_transfer_key);
+ AttrValue device_ordinal_value;
+ device_ordinal_value.set_placeholder("device_ordinal");
+ recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
+ recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+ recv_pred_builder.Attr(outside_compilation_attr_name,
+ outside_compilation_name);
+ recv_pred_builder.Input(key_arg->name(), 0, DT_STRING);
+ NodeDef recv_pred_def;
+ TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
+ Status s;
+ Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0);
+ NodeDefBuilder ret_builder(
+ absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval");
+ ret_builder.Attr("T", DT_BOOL);
+ ret_builder.Attr("index", 0);
+ ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
+ NodeDef ret_def;
+ TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
+ Node* ret_node = cond_graph->AddNode(ret_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0);
+
+ // Reset device_ordinal to placeholder value.
+ TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph));
+
+ // Replace original function.
+ FunctionDef cond_replace_fdef;
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*cond_graph, cond_host_func_name, &cond_replace_fdef));
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef));
+
+ return Status::OK();
+}
+
+// Rewrites the host-side while loop body function `body_host_func_name` in
+// `fld` so that the host compute key is threaded through the loop: the key
+// placeholder is replaced by an _Arg node, and a DT_STRING _Retval node
+// (index 0) fed by that _Arg is added. The rewritten function replaces the
+// original one in `fld`.
+//
+// `while_node_name` is only used to name the new _Retval node.
+// `host_transfer_key`, `xla_cluster_attr_name`,
+// `outside_compilation_attr_name` and `outside_compilation_name` are not read
+// by this function.
+//
+// Returns an Internal error if the function is missing from `fld` or has no
+// _Arg node; otherwise propagates the first error from the helpers it calls.
+Status RewriteHostWhileLoopBody(
+    const string& body_host_func_name, const string& while_node_name,
+    const string& host_transfer_key, const string& xla_cluster_attr_name,
+    const string& xla_cluster_name, const string& outside_compilation_attr_name,
+    const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
+  // Replace key placeholder node with _Arg node.
+  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+      xla_cluster_name, body_host_func_name, fld));
+
+  // Look the function up explicitly so a missing entry is reported as an
+  // error instead of dereferencing a null pointer below.
+  const FunctionDef* body_fdef = fld->Find(body_host_func_name);
+  if (!body_fdef) {
+    return errors::Internal("Cannot find function ", body_host_func_name,
+                            " in function library");
+  }
+
+  // Instantiate body function. "device_ordinal" is bound to a temporary value
+  // (0) for instantiation only; it is reset to a placeholder further down.
+  AttrValue device_ordinal_temp_value;
+  device_ordinal_temp_value.set_i(0);
+  protobuf::Map<string, AttrValue> attrs;
+  attrs["device_ordinal"] = device_ordinal_temp_value;
+  FunctionBody* body_fbody = nullptr;
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+      *body_fdef, AttrSlice(&attrs), fld,
+      [&](const string& op, const OpDef** sig) {
+        return fld->LookUpOpDef(op, sig);
+      },
+      &body_fbody));
+  std::unique_ptr<FunctionBody> body_fbody_deleter(body_fbody);
+  Graph* body_graph = body_fbody->graph;
+
+  // Find the _Arg node carrying the key. If the graph has several _Arg nodes,
+  // the last one visited is used.
+  Node* key_arg = nullptr;
+  for (Node* n : body_graph->nodes()) {
+    if (n->type_string() == "_Arg") {
+      key_arg = n;
+    }
+  }
+  if (!key_arg) {
+    return errors::Internal(
+        "No _Arg node found for host compute key in function ",
+        body_host_func_name);
+  }
+
+  // Add a _Retval node to loop body.
+  NodeDefBuilder ret_builder(
+      absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval");
+  ret_builder.Attr("T", DT_STRING);
+  ret_builder.Attr("index", 0);
+  ret_builder.Input(key_arg->name(), 0, DT_STRING);
+  NodeDef ret_def;
+  TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
+  Status s;
+  Node* ret_node = body_graph->AddNode(ret_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  body_graph->AddEdge(key_arg, 0, ret_node, 0);
+
+  // Reset device_ordinal to placeholder value.
+  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph));
+
+  // Replace original function.
+  FunctionDef body_replace_fdef;
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*body_graph, body_host_func_name, &body_replace_fdef));
+  TF_RETURN_IF_ERROR(
+      fld->ReplaceFunction(body_host_func_name, body_replace_fdef));
+
+  return Status::OK();
+}
+
+// Builds the host side graph for a "While" node and stores it in `fld` as the
+// function `host_graph_func_name`.
+//
+// The host graph consists of a host compute key placeholder feeding a single
+// "While" node whose only loop variable is that DT_STRING key. The node's
+// "cond" and "body" attributes point at `cond_host_func_name` and
+// `body_host_func_name`, both of which are rewritten in `fld` first. Each
+// function reference carries a "device_ordinal" placeholder attribute.
+//
+// If `fld` already has a function named `host_graph_func_name` it is
+// replaced; otherwise the new function is added. Returns the first error
+// reported by any step.
+Status BuildHostGraphForWhileNode(
+    const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name, const string& xla_cluster_name,
+    const string& while_node_name, const string& host_transfer_key,
+    const string& host_graph_func_name, FunctionLibraryDefinition* fld,
+    const string& cond_host_func_name, const string& body_host_func_name) {
+  Graph host_graph(fld);
+  // Used both as the outside compilation cluster name and as the name of the
+  // host "While" node built in step 4.
+  const string outside_compilation_name =
+      absl::StrCat("oc_while_", while_node_name);
+
+  // Step 1: add key placeholder node.
+  TF_ASSIGN_OR_RETURN(
+      Node * key_placeholder,
+      AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
+
+  // Step 2: rewrite cond function.
+  TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond(
+      cond_host_func_name, while_node_name, host_transfer_key,
+      xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
+      outside_compilation_name, fld));
+
+  // Step 3: rewrite body function.
+  TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody(
+      body_host_func_name, while_node_name, host_transfer_key,
+      xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
+      outside_compilation_name, fld));
+
+  // Step 4: build While node.
+  NodeDefBuilder while_builder(outside_compilation_name, "While");
+  while_builder.Attr("T", std::vector<DataType>{DT_STRING});
+  // `func` is reused for "cond" and "body"; only its name differs between the
+  // two attributes, the "device_ordinal" placeholder is shared.
+  NameAttrList func;
+  AttrValue device_ordinal_value;
+  device_ordinal_value.set_placeholder("device_ordinal");
+  (*func.mutable_attr())["device_ordinal"] = device_ordinal_value;
+  func.set_name(cond_host_func_name);
+  while_builder.Attr("cond", func);
+  func.set_name(body_host_func_name);
+  while_builder.Attr("body", func);
+  while_builder.Attr(kXlaHasHostTransferAttrName, true);
+  while_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+  while_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
+  std::vector<NodeDefBuilder::NodeOut> while_inputs{
+      {key_placeholder->name(), 0, DT_STRING}};
+  while_builder.Input(while_inputs);
+  NodeDef while_def;
+  TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def));
+  Status s;
+  Node* while_node = host_graph.AddNode(while_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  host_graph.AddEdge(key_placeholder, 0, while_node, 0);
+
+  // Convert `host_graph` to function.
+  FunctionDef oc_host_graph_fdef;
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
+                                        &oc_host_graph_fdef));
+  if (fld->Find(host_graph_func_name)) {
+    TF_RETURN_IF_ERROR(
+        fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
+  } else {
+    TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
+  }
+
+  return Status::OK();
+}
+
+// Extracts outside compilation from the functions attached to "If" and
+// "While" nodes in `g`.
+//
+// For every "If" node (attrs "then_branch"/"else_branch") and every "While"
+// node (attrs "cond"/"body"), ExtractOutsideCompilationForFunction is run on
+// both associated functions. If neither has outside compilation the node is
+// left untouched. Otherwise:
+//   * `*has_outside_compilation` is set to true (it is never set to false
+//     here, so callers must initialize it);
+//   * the node's function attrs are pointed at the rewritten XLA functions
+//     (original function name with an "_oc" suffix);
+//   * the predicate is sent to the host: for "If" a send node is added to `g`
+//     with a control edge to the "If" node, for "While" the rewritten cond
+//     function is modified;
+//   * a host graph function is built in `fld` and its name is appended to
+//     `host_graphs`.
+//
+// Names of shape inference graphs are accumulated in `shape_inference_graphs`
+// by the per-function extraction. Returns the first error encountered.
+Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
+    Graph* g, const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name, const string& xla_cluster_name,
+    const std::map<string, int>& host_compute_core,
+    FunctionLibraryDefinition* fld, std::vector<string>* host_graphs,
+    std::vector<string>* shape_inference_graphs,
+    bool* has_outside_compilation) {
+  // Collect the nodes up front; the loops below modify `g`.
+  std::vector<Node*> if_nodes, while_nodes;
+  for (Node* n : g->nodes()) {
+    if (n->type_string() == "If") {
+      if_nodes.push_back(n);
+    } else if (n->type_string() == "While") {
+      while_nodes.push_back(n);
+    }
+  }
+
+  for (Node* n : if_nodes) {
+    // Instantiate "then_branch" and "else_branch".
+    NameAttrList then_branch, else_branch;
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
+
+    // Extract outside compilation for then_branch and else_branch.
+    bool then_branch_has_outside_compilation = false;
+    bool else_branch_has_outside_compilation = false;
+    const string then_branch_host_func_name =
+        absl::StrCat("oc_then_branch_host_if_", n->name());
+    const string else_branch_host_func_name =
+        absl::StrCat("oc_else_branch_host_if_", n->name());
+    const string then_branch_xla_func_name =
+        absl::StrCat(then_branch.name(), "_oc");
+    const string else_branch_xla_func_name =
+        absl::StrCat(else_branch.name(), "_oc");
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        then_branch, then_branch_xla_func_name, then_branch_host_func_name,
+        host_compute_core, fld, shape_inference_graphs,
+        &then_branch_has_outside_compilation));
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        else_branch, else_branch_xla_func_name, else_branch_host_func_name,
+        host_compute_core, fld, shape_inference_graphs,
+        &else_branch_has_outside_compilation));
+
+    // If then/else branch do not have outside compilation, nothing to do.
+    if (!then_branch_has_outside_compilation &&
+        !else_branch_has_outside_compilation) {
+      continue;
+    }
+
+    *has_outside_compilation = true;
+
+    // Change If node to call the new functions.
+    then_branch.set_name(then_branch_xla_func_name);
+    n->ClearAttr("then_branch");
+    n->AddAttr("then_branch", then_branch);
+    else_branch.set_name(else_branch_xla_func_name);
+    n->ClearAttr("else_branch");
+    n->AddAttr("else_branch", else_branch);
+
+    const string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
+
+    // XLA computation: add a SendToHost node to send cond predicate.
+    Node* pred_node = nullptr;
+    TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
+    TF_ASSIGN_OR_RETURN(
+        Node * send_pred_node,
+        BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
+                            host_transfer_key, pred_node, g));
+    n->AddAttr(kXlaTokenInputNodesAttrName,
+               std::vector<string>{send_pred_node->name()});
+
+    // Add a control edge from `send_pred_node` to If node, so XlaCompiler will
+    // visit If node after `send_pred_node`, thus the token output for
+    // `send_pred_node` has been generated.
+    g->AddControlEdge(send_pred_node, n);
+
+    // Build host side graph for the "If" node.
+    const string oc_host_graph_name =
+        absl::StrCat("oc_if_host_graph_", n->name());
+    TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        n->name(), host_transfer_key, oc_host_graph_name, fld,
+        then_branch_host_func_name, else_branch_host_func_name));
+    host_graphs->push_back(oc_host_graph_name);
+  }
+
+  for (Node* n : while_nodes) {
+    // Instantiate "cond" and "body".
+    NameAttrList cond, body;
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
+
+    // Extract outside compilation for cond and body.
+    bool cond_has_outside_compilation = false;
+    bool body_has_outside_compilation = false;
+    const string cond_host_func_name =
+        absl::StrCat("oc_cond_host_while_", n->name());
+    const string body_host_func_name =
+        absl::StrCat("oc_body_host_while_", n->name());
+    const string cond_xla_func_name = absl::StrCat(cond.name(), "_oc");
+    const string body_xla_func_name = absl::StrCat(body.name(), "_oc");
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        cond, cond_xla_func_name, cond_host_func_name, host_compute_core, fld,
+        shape_inference_graphs, &cond_has_outside_compilation));
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        body, body_xla_func_name, body_host_func_name, host_compute_core, fld,
+        shape_inference_graphs, &body_has_outside_compilation));
+
+    // If cond/body do not have outside compilation, nothing to do.
+    if (!cond_has_outside_compilation && !body_has_outside_compilation) {
+      continue;
+    }
+
+    *has_outside_compilation = true;
+
+    // Change While node to call the new functions.
+    cond.set_name(cond_xla_func_name);
+    n->ClearAttr("cond");
+    n->AddAttr("cond", cond);
+    body.set_name(body_xla_func_name);
+    n->ClearAttr("body");
+    n->AddAttr("body", body);
+
+    const string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
+
+    // XLA computation: rewrite cond function to add a SendToHost node to send
+    // loop predicate. `cond` already names the rewritten XLA cond function at
+    // this point.
+    TF_RETURN_IF_ERROR(
+        AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key));
+    n->AddAttr(kXlaTokenInputNodesAttrName,
+               std::vector<string>{kXlaTokenArgNodeName});
+
+    // Build host side graph for the "While" node.
+    const string oc_host_graph_name =
+        absl::StrCat("oc_while_host_graph_", n->name());
+    TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        n->name(), host_transfer_key, oc_host_graph_name, fld,
+        cond_host_func_name, body_host_func_name));
+    host_graphs->push_back(oc_host_graph_name);
+  }
+
+  return Status::OK();
+}
+
} // namespace
Status RewriteOutsideCompilationSubgraphFn::operator()(
@@ -755,12 +1445,15 @@ Status RewriteOutsideCompilationSubgraphFn::operator()(
// it with HostCompute node later.
AddNodeAttr("_outside_compilation_subgraph", old_name, node_def);
if (shapes) {
- AddNodeAttr("shape_inference_graph", "", node_def);
+ NameAttrList shape_inference_graph;
+ AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
AddNodeAttr("shapes", *shapes, node_def);
} else {
string shape_inference_func_name =
absl::StrCat("_outside_compilation_shape_inference_", new_name);
- AddNodeAttr("shape_inference_graph", shape_inference_func_name, node_def);
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name(shape_inference_func_name);
+ AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
AddNodeAttr("shapes", std::vector{}, node_def);
}
AddNodeAttr("ancestors", std::vector{}, node_def);
@@ -775,11 +1468,10 @@ Status ExtractOutsideCompilationForFunction(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name, const string& xla_cluster_name,
const NameAttrList& func_name_attrs, const string& new_func_name,
+ const string& host_graph_func_name,
const std::map& host_compute_core,
- FunctionLibraryDefinition* fld, std::unique_ptr* host_graph,
- std::vector* shape_inference_graphs,
+ FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs,
bool* has_outside_compilation) {
- // Early return if function does not have any outside compilation nodes.
const string& func_name = func_name_attrs.name();
const FunctionDef* fdef = fld->Find(func_name);
if (!fdef) {
@@ -792,9 +1484,8 @@ Status ExtractOutsideCompilationForFunction(
break;
}
}
- if (!has_outside_compilation) {
- return Status::OK();
- }
+ // We cannot early return here, because we might have outside compilation in
+ // If/While function body.
// Convert the function to graph.
FunctionBody* fbody = nullptr;
@@ -835,11 +1526,11 @@ Status ExtractOutsideCompilationForFunction(
// If we could not infer shapes for XlaSendFromHost inputs statically, we
// will set the "shape_inference_graph" attribute. In that case, copy
// outside compilation subgraph as shape inference graph in `fld`.
- string shape_inference_graph;
+ NameAttrList shape_inference_graph;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
&shape_inference_graph));
- if (!shape_inference_graph.empty()) {
- shape_inference_graphs->push_back(shape_inference_graph);
+ if (!shape_inference_graph.name().empty()) {
+ shape_inference_graphs->push_back(shape_inference_graph.name());
const FunctionDef* xla_fdef = fld->Find(n->name());
if (!xla_fdef) {
@@ -847,9 +1538,9 @@ Status ExtractOutsideCompilationForFunction(
}
FunctionDef shape_inference_fdef = *xla_fdef;
shape_inference_fdef.mutable_signature()->set_name(
- shape_inference_graph);
- if (fld->Find(shape_inference_graph)) {
- TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph,
+ shape_inference_graph.name());
+ if (fld->Find(shape_inference_graph.name())) {
+ TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(),
shape_inference_fdef));
} else {
TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef));
@@ -858,6 +1549,7 @@ Status ExtractOutsideCompilationForFunction(
}
}
for (Node* n : outside_compilation_nodes) {
+ TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core));
}
@@ -867,12 +1559,17 @@ Status ExtractOutsideCompilationForFunction(
*graph_out, fld);
}
+ // Handle nodes with associated functions.
+ TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions(
+ graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name,
+ xla_cluster_name, host_compute_core, fld,
+ &outside_compilation_host_graphs, shape_inference_graphs,
+ has_outside_compilation));
+
// Construct host graph.
- if (!outside_compilation_host_graphs.empty()) {
- TF_RETURN_IF_ERROR(
- ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
- outside_compilation_host_graphs, fld, host_graph));
- }
+ TF_RETURN_IF_ERROR(ConstructHostGraph(
+ xla_cluster_name, outside_compilation_attr_name,
+ outside_compilation_host_graphs, fld, host_graph_func_name));
// Remove the outside compilation graphs from function library.
for (const string& func : outside_compilation_host_graphs) {
@@ -909,24 +1606,17 @@ Status ExtractOutsideCompilation(
auto const& host_compute_core = iter.second.host_compute_core;
bool has_outside_compilation;
- std::unique_ptr host_graph;
+ string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name());
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
- func_name_attrs, func_name_attrs.name(), host_compute_core, fld,
- &host_graph, &shape_inference_graphs, &has_outside_compilation));
- if (host_graph) {
- TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(g, host_graph.get(), n));
- }
- }
-
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile("extract_outside_compilation_expanded", *g,
- fld);
+ func_name_attrs, func_name_attrs.name(), host_graph_func_name,
+ host_compute_core, fld, &shape_inference_graphs,
+ &has_outside_compilation));
+ TF_RETURN_IF_ERROR(
+ ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n));
+ TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
}
- TF_RETURN_IF_ERROR(PostprocessForEncapsulation(
- g, xla_cluster_attr_name, outside_compilation_attr_name, clusters));
-
for (auto shape_inference_graph_name : shape_inference_graphs) {
TF_RETURN_IF_ERROR(
RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld));
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h
index 2a4f07cca213d999202024294f5d8f94527059c3..e07e7c5dd0cd42ddd4d643d8b36583c82056bbb2 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h
@@ -88,9 +88,10 @@ Status ExtractOutsideCompilationForFunction(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name, const string& xla_cluster_name,
const NameAttrList& func_name_attrs, const string& new_func_name,
+ const string& host_graph_func_name,
const std::map& host_compute_core,
- FunctionLibraryDefinition* fld, std::unique_ptr* host_graph,
- std::vector* shape_inference_graphs, bool* has_outside_compilation);
+ FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs,
+ bool* has_outside_compilation);
// Rewrites XLA computation in `clusters` to replace outside compilation nodes
// with XlaHostCompute, and moves those outside compilations into `g`. If shapes
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
index bff956100da661b679b4557fce53671e6cef88c5..e9a89e34e0c7b04b4be34e367b2d0bf627c0061a 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
@@ -19,8 +19,10 @@ limitations under the License.
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
+#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
@@ -109,10 +111,10 @@ TEST(RewriteOutsideCompilationSubgraphFnTest, Basic) {
}
EXPECT_TRUE(has_control_edge_to_send_from_host);
// Verify step 7: necessary attrs added to call_node_def.
- string shape_inference_graph;
+ NameAttrList shape_inference_graph;
TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()),
"shape_inference_graph", &shape_inference_graph));
- EXPECT_EQ(shape_inference_graph,
+ EXPECT_EQ(shape_inference_graph.name(),
"_outside_compilation_shape_inference_cluster_0");
}
@@ -249,27 +251,26 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) {
protobuf::Map attrs;
std::map host_compute_core = {{"0", 1}, {"1", 0}};
- std::unique_ptr host_graph;
std::vector shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationForFunction(
- "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten",
- host_compute_core, &fld, &host_graph, &shape_inference_graphs,
+ "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
+ host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
// Get rewritten XLA computation function.
- FunctionBody *fbody = nullptr;
- TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
- AttrSlice(), &fld,
- [&](const string &op, const OpDef **sig) {
- return fld.LookUpOpDef(op, sig);
- },
- &fbody));
- std::unique_ptr fbody_deleter(fbody);
- auto node_name_index = fbody->graph->BuildNodeNameIndex();
+ FunctionBody *xla_fbody = nullptr;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("cluster_rewritten"), AttrSlice(), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &xla_fbody));
+ std::unique_ptr xla_fbody_deleter(xla_fbody);
+ auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
// Check XlaHostCompute nodes.
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
@@ -292,18 +293,31 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) {
EXPECT_EQ(shapes[0].dim_size(), 1);
// Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have
// empty values.
- string shape_inference_graph;
+ NameAttrList shape_inference_graph;
TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph",
&shape_inference_graph));
- EXPECT_EQ(shape_inference_graph, "");
+ EXPECT_EQ(shape_inference_graph.name(), "");
TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph",
&shape_inference_graph));
- EXPECT_EQ(shape_inference_graph, "");
+ EXPECT_EQ(shape_inference_graph.name(), "");
// Check `shape_inference_graphs`.
EXPECT_EQ(shape_inference_graphs.size(), 0);
- // Check `host_graph`: verify we have key placeholder and sequencer.
+ // Check host graph: verify we have key placeholder and sequencer.
+ FunctionBody *host_fbody = nullptr;
+ AttrValue device_ordinal_temp_value;
+ device_ordinal_temp_value.set_i(0);
+ protobuf::Map host_func_attrs;
+ host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &host_fbody));
+ std::unique_ptr host_fbody_deleter(host_fbody);
+ Graph *host_graph = host_fbody->graph;
Node *key_placeholder = nullptr, *sequencer = nullptr;
for (Node *n : host_graph->nodes()) {
if (n->type_string() == "Placeholder" &&
@@ -365,25 +379,37 @@ TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) {
protobuf::Map attrs;
std::map host_compute_core = {{"0", 1}, {"1", 0}};
- std::unique_ptr host_graph;
std::vector shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationForFunction(
- "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten",
- host_compute_core, &fld, &host_graph, &shape_inference_graphs,
+ "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
+ host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
- // Check `host_graph` is empty.
- EXPECT_FALSE(host_graph);
+ // Check host graph is empty.
+ FunctionBody *host_fbody = nullptr;
+ AttrValue device_ordinal_temp_value;
+ device_ordinal_temp_value.set_i(0);
+ protobuf::Map host_func_attrs;
+ host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &host_fbody));
+ std::unique_ptr host_fbody_deleter(host_fbody);
+ Graph *host_graph = host_fbody->graph;
+ EXPECT_EQ(host_graph->num_nodes(), 2);
}
TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) {
// Build the XLA computation func.
// "const0"
- // "const1" (outside compilation clsuter "0")
+ // "const1" (outside compilation cluster "0")
FunctionDefLibrary fdl;
{
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -401,31 +427,43 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) {
protobuf::Map attrs;
std::map host_compute_core = {{"0", 1}, {"1", 0}};
- std::unique_ptr host_graph;
std::vector shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationForFunction(
- "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten",
- host_compute_core, &fld, &host_graph, &shape_inference_graphs,
+ "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
+ host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
// Check rewritten XLA graph: verify that we have no XlaHostCompute.
- FunctionBody *fbody = nullptr;
- TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
- AttrSlice(), &fld,
- [&](const string &op, const OpDef **sig) {
- return fld.LookUpOpDef(op, sig);
- },
- &fbody));
- std::unique_ptr fbody_deleter(fbody);
- for (Node *n : fbody->graph->nodes()) {
+ FunctionBody *xla_fbody = nullptr;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("cluster_rewritten"), AttrSlice(), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &xla_fbody));
+ std::unique_ptr xla_fbody_deleter(xla_fbody);
+ for (Node *n : xla_fbody->graph->nodes()) {
EXPECT_NE(n->type_string(), "XlaHostCompute");
}
- // Check `host_graph`: verify we have no placeholder, but we have "const1".
+ // Check host graph: verify we have no placeholder, but we have "const1".
+ FunctionBody *host_fbody = nullptr;
+ AttrValue device_ordinal_temp_value;
+ device_ordinal_temp_value.set_i(0);
+ protobuf::Map host_func_attrs;
+ host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &host_fbody));
+ std::unique_ptr host_fbody_deleter(host_fbody);
+ Graph *host_graph = host_fbody->graph;
int num_key_placeholders = 0;
for (Node *n : host_graph->nodes()) {
if (n->type_string() == "Placeholder" &&
@@ -438,4 +476,310 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) {
EXPECT_NE(node_name_index.find("const1"), node_name_index.end());
}
+REGISTER_OP("XlaSendToHost")  // Op registration local to this test file.
+ .Input("input: Tinput")
+ .Attr("Tinput: type")
+ .Attr("key: string")  // NOTE(review): presumably the host transfer key.
+ .SetIsStateful();
+
+REGISTER_OP("XlaRecvFromHost")  // Counterpart op: one output, no inputs.
+ .Output("output: Toutput")
+ .Attr("Toutput: type")
+ .Attr("shape: shape")
+ .Attr("key: string")  // NOTE(review): presumably the host transfer key.
+ .SetIsStateful();
+
+// Verifies outside compilation extraction for an "If" node whose then/else
+// branch functions each contain one outside compilation node: checks the
+// generated host graph and host branch functions, and the send-predicate
+// plumbing added to the rewritten XLA graph.
+TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) {
+  // Build the XLA computation func.
+  // "const0" (bool)
+  // "const1" (int32)
+  // "if" (pred = "const0", inputs = {"const0", "const1"},
+  //       then_branch = "true_fn", else_branch = "false_fn")
+  FunctionDefLibrary fdl;
+  {
+    // "true_fn": arg -> identity_true_fn (outside compilation "0") -> retval.
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0);
+    Output identity = ops::Identity(s.WithOpName("identity_true_fn"), arg);
+    ops::_Retval retval(s.WithOpName("retval"), identity, 0);
+    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+    TF_CHECK_OK(s.ToGraph(g.get()));
+    auto node_name_index = g->BuildNodeNameIndex();
+    node_name_index["identity_true_fn"]->AddAttr("_oc", "0");
+    PartialTensorShape shape({2});
+    node_name_index["identity_true_fn"]->AddAttr(
+        kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
+
+    FunctionDef *true_fn_fdef = fdl.add_function();
+    TF_CHECK_OK(GraphToFunctionDef(*g, "true_fn", true_fn_fdef));
+  }
+  {
+    // "false_fn": arg -> identity_false_fn (outside compilation "0") ->
+    // retval.
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0);
+    Output identity = ops::Identity(s.WithOpName("identity_false_fn"), arg);
+    ops::_Retval retval(s.WithOpName("retval"), identity, 0);
+    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+    TF_CHECK_OK(s.ToGraph(g.get()));
+    auto node_name_index = g->BuildNodeNameIndex();
+    node_name_index["identity_false_fn"]->AddAttr("_oc", "0");
+    PartialTensorShape shape({2});
+    node_name_index["identity_false_fn"]->AddAttr(
+        kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
+
+    FunctionDef *false_fn_fdef = fdl.add_function();
+    TF_CHECK_OK(GraphToFunctionDef(*g, "false_fn", false_fn_fdef));
+  }
+  {
+    // "cluster": the XLA computation containing the "if" node.
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    Output cond = ops::Const(s.WithOpName("const0"), true, {2});
+    Output input = ops::Const(s.WithOpName("const1"), 1, {2});
+    NameAttrList true_fn;
+    true_fn.set_name("true_fn");
+    NameAttrList false_fn;
+    false_fn.set_name("false_fn");
+    auto if_op = ops::If(s.WithOpName("if"), cond,
+                         std::initializer_list<Input>{cond, input}, {DT_INT32},
+                         true_fn, false_fn);
+    ops::_Retval retval(s.WithOpName("retval"), if_op.output[0], 0);
+    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+    TF_CHECK_OK(s.ToGraph(g.get()));
+
+    FunctionDef *xla_fdef = fdl.add_function();
+    TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
+  }
+  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
+
+  protobuf::Map<string, AttrValue> attrs;
+  std::map<string, int> host_compute_core;
+  std::vector<string> shape_inference_graphs;
+  bool has_outside_compilation = false;
+  NameAttrList name_attrs;
+  name_attrs.set_name("cluster");
+  *name_attrs.mutable_attr() = attrs;
+  TF_CHECK_OK(ExtractOutsideCompilationForFunction(
+      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
+      host_compute_core, &fld, &shape_inference_graphs,
+      &has_outside_compilation));
+
+  // Check host graph.
+  {
+    // The host function is instantiated with a temporary "device_ordinal"
+    // value of 0.
+    FunctionBody *host_fbody = nullptr;
+    AttrValue device_ordinal_temp_value;
+    device_ordinal_temp_value.set_i(0);
+    protobuf::Map<string, AttrValue> host_func_attrs;
+    host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
+    TF_CHECK_OK(FunctionDefToBodyHelper(
+        *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
+        [&](const string &op, const OpDef **sig) {
+          return fld.LookUpOpDef(op, sig);
+        },
+        &host_fbody));
+    std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody);
+    Graph *host_graph = host_fbody->graph;
+    auto node_name_index = host_graph->BuildNodeNameIndex();
+
+    // Verify we have XlaRecvAtHost to receive "If" predicate.
+    Node *recv_if_pred_node = node_name_index["recv_oc_if_pred_if"];
+    EXPECT_NE(recv_if_pred_node, nullptr);
+
+    // Verify we have an "If" to choose outside compilation between then_branch
+    // and else_branch, and it has `recv_if_pred_node` as cond input.
+    // ASSERT (not EXPECT) because the pointer is dereferenced right after.
+    Node *if_oc_node = node_name_index["oc_if_if"];
+    ASSERT_NE(if_oc_node, nullptr);
+    Node *if_oc_node_cond_input = nullptr;
+    TF_CHECK_OK(if_oc_node->input_node(0, &if_oc_node_cond_input));
+    EXPECT_EQ(if_oc_node_cond_input, recv_if_pred_node);
+
+    // Check that then_branch outside compilation has node "identity_true_fn".
+    const FunctionDef *true_def = fld.Find("oc_then_branch_host_if_if");
+    ASSERT_NE(true_def, nullptr);
+    bool has_identity_true_fn_node = false;
+    for (const auto &node_def : true_def->node_def()) {
+      if (node_def.name() == "identity_true_fn") {
+        has_identity_true_fn_node = true;
+        break;
+      }
+    }
+    EXPECT_TRUE(has_identity_true_fn_node);
+
+    // Check that else_branch outside compilation has node "identity_false_fn".
+    const FunctionDef *false_def = fld.Find("oc_else_branch_host_if_if");
+    ASSERT_NE(false_def, nullptr);
+    bool has_identity_false_fn_node = false;
+    for (const auto &node_def : false_def->node_def()) {
+      if (node_def.name() == "identity_false_fn") {
+        has_identity_false_fn_node = true;
+        break;
+      }
+    }
+    EXPECT_TRUE(has_identity_false_fn_node);
+  }
+
+  // Check XLA graph.
+  {
+    FunctionBody *xla_fbody = nullptr;
+    TF_CHECK_OK(FunctionDefToBodyHelper(
+        *fld.Find("cluster_rewritten"), AttrSlice(), &fld,
+        [&](const string &op, const OpDef **sig) {
+          return fld.LookUpOpDef(op, sig);
+        },
+        &xla_fbody));
+    std::unique_ptr<FunctionBody> xla_fbody_deleter(xla_fbody);
+    Graph *xla_graph = xla_fbody->graph;
+    auto node_name_index = xla_graph->BuildNodeNameIndex();
+
+    // Check that we have XlaSendToHost to send cond predicate to host, and
+    // there is a control edge to If node.
+    Node *send_if_pred_node = node_name_index["send_oc_if_pred_if"];
+    ASSERT_NE(send_if_pred_node, nullptr);
+    bool has_control_edge_to_if = false;
+    for (const Edge *e : send_if_pred_node->out_edges()) {
+      if (e->IsControlEdge() && e->dst()->name() == "if") {
+        has_control_edge_to_if = true;
+        break;
+      }
+    }
+    EXPECT_TRUE(has_control_edge_to_if);
+
+    // Check that the "If" node now has `send_if_pred_node` as attribute
+    // _xla_token_input_nodes.
+    Node *if_node = node_name_index["if"];
+    ASSERT_NE(if_node, nullptr);
+    std::vector<string> token_inputs;
+    TF_CHECK_OK(
+        GetNodeAttr(if_node->def(), "_xla_token_input_nodes", &token_inputs));
+    EXPECT_THAT(token_inputs, ::testing::ElementsAre("send_oc_if_pred_if"));
+  }
+}
+
+TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) {
+ // Build the XLA computation func.
+ // "const0" (bool)
+ // "while" (input = "const0", cond = "cond_fn", body = "body_fn")
+ FunctionDefLibrary fdl;
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0);
+ Output identity = ops::Identity(s.WithOpName("identity_cond_fn"), arg);
+ ops::_Retval retval(s.WithOpName("retval"), identity, 0);
+ std::unique_ptr g(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(s.ToGraph(g.get()));
+ auto node_name_image = g->BuildNodeNameIndex();
+ node_name_image["identity_cond_fn"]->AddAttr("_oc", "0");
+ PartialTensorShape shape({2});
+ node_name_image["identity_cond_fn"]->AddAttr(
+ kXlaInferredShapesAttrName, std::vector{shape});
+
+ FunctionDef *cond_fn_fdef = fdl.add_function();
+ TF_CHECK_OK(GraphToFunctionDef(*g, "cond_fn", cond_fn_fdef));
+ }
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0);
+ Output identity = ops::Identity(s.WithOpName("identity_body_fn"), arg);
+ ops::_Retval retval(s.WithOpName("retval"), identity, 0);
+ std::unique_ptr g(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(s.ToGraph(g.get()));
+ auto node_name_image = g->BuildNodeNameIndex();
+ node_name_image["identity_body_fn"]->AddAttr("_oc", "0");
+ PartialTensorShape shape({2});
+ node_name_image["identity_body_fn"]->AddAttr(
+ kXlaInferredShapesAttrName, std::vector{shape});
+
+ FunctionDef *body_fn_fdef = fdl.add_function();
+ TF_CHECK_OK(GraphToFunctionDef(*g, "body_fn", body_fn_fdef));
+ }
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input = ops::Const(s.WithOpName("const0"), true, {2});
+ NameAttrList cond_fn;
+ cond_fn.set_name("cond_fn");
+ NameAttrList body_fn;
+ body_fn.set_name("body_fn");
+ auto while_op =
+ ops::While(s.WithOpName("while"), std::initializer_list{input},
+ cond_fn, body_fn);
+ ops::_Retval retval(s.WithOpName("retval"), while_op.output[0], 0);
+ std::unique_ptr g(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(s.ToGraph(g.get()));
+
+ FunctionDef *xla_fdef = fdl.add_function();
+ TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
+ }
+ FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
+
+ protobuf::Map attrs;
+ std::map host_compute_core;
+ std::vector shape_inference_graphs;
+ bool has_outside_compilation;
+ NameAttrList name_attrs;
+ name_attrs.set_name("cluster");
+ *name_attrs.mutable_attr() = attrs;
+ TF_CHECK_OK(ExtractOutsideCompilationForFunction(
+ "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
+ host_compute_core, &fld, &shape_inference_graphs,
+ &has_outside_compilation));
+
+ // Check host graph.
+ {
+ FunctionBody *host_fbody = nullptr;
+ AttrValue device_ordinal_temp_value;
+ device_ordinal_temp_value.set_i(0);
+ protobuf::Map host_func_attrs;
+ host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &host_fbody));
+ std::unique_ptr host_fbody_deleter(host_fbody);
+ Graph *host_graph = host_fbody->graph;
+ auto node_name_index = host_graph->BuildNodeNameIndex();
+
+ // Verify we have a "While" to execute outside compilation.
+ Node *while_oc_node = node_name_index["oc_while_while"];
+ EXPECT_NE(while_oc_node, nullptr);
+
+ // Check that cond outside compilation has node "identity_cond_fn".
+ const FunctionDef *cond_def = fld.Find("oc_cond_host_while_while");
+ ASSERT_NE(cond_def, nullptr);  // Fatal: cond_def is dereferenced below.
+ bool has_identity_cond_fn_node = false;
+ for (const auto &node_def : cond_def->node_def()) {
+ if (node_def.name() == "identity_cond_fn") {
+ has_identity_cond_fn_node = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_identity_cond_fn_node);
+
+ // Check that body outside compilation has node "identity_body_fn".
+ const FunctionDef *body_def = fld.Find("oc_body_host_while_while");
+ ASSERT_NE(body_def, nullptr);  // Fatal: body_def is dereferenced below.
+ bool has_identity_body_fn_node = false;
+ for (const auto &node_def : body_def->node_def()) {
+ if (node_def.name() == "identity_body_fn") {
+ has_identity_body_fn_node = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_identity_body_fn_node);
+ }
+
+ // Check XLA graph.
+ {
+ // Verify that rewritten cond fn has XlaSendToHost to send loop predicate to
+ // host.
+ const FunctionDef *cond_def = fld.Find("cond_fn_oc");
+ ASSERT_NE(cond_def, nullptr);  // Fatal: cond_def is dereferenced below.
+ bool has_send_oc_while_cond_node = false;
+ for (const auto &node_def : cond_def->node_def()) {
+ if (node_def.name() == "send_oc_while_cond_while") {
+ has_send_oc_while_cond_node = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_send_oc_while_cond_node);
+ }
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 25796435a5c87af5e252981abf96833f4cda9a5e..6618e3a58ab7b6374ed775cd6e4e18a6a4975588 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -86,7 +86,7 @@ bool IsDummyImplOp(absl::string_view op_name) {
bool IsStatefulRandomOp(absl::string_view op_name) {
return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
- op_name == "TruncatedNormal";
+ op_name == "TruncatedNormal" || op_name == "Multinomial";
}
bool OpProducesOrConsumesVariant(const Node& node) {
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 42ea3926e16ae791dbe1bede3b8742383db7667c..e1fd2aaee2822daeffb415d053c9c4f56002a856 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -120,6 +120,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
NodeDef ndef = n->def();
ndef.set_name(absl::StrCat(n->name(), "/declustered"));
+ MergeDebugInfo(NodeDebugInfo(n->def()), &ndef);
RemoveFromXlaCluster(&ndef);
Status s;
Node* cloned_node = graph->AddNode(ndef, &s);
diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc
index 80c691fe490c1092315708a2da754d367d585300..a27e0d9f2a6ecddfdbdb29be673084d77a178d8a 100644
--- a/tensorflow/compiler/jit/shape_inference.cc
+++ b/tensorflow/compiler/jit/shape_inference.cc
@@ -53,7 +53,15 @@ Status PropagateShapes(const Graph& graph,
// shapes, even if no shape function is registered for a node.
Status status = shape_refiner->AddNode(n);
if (!status.ok()) {
- VLOG(1) << "Shape inference failed for node: " << status;
+ VLOG(1) << "Shape inference failed for node " << n->name() << ": "
+ << status;
+ } else {
+ shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
+ for (int i = 0; i < n->num_outputs(); i++) {
+ shape_inference::ShapeHandle handle = context->output(i);
+ VLOG(4) << "Output " << i << " for node " << n->name() << ": "
+ << context->DebugString(handle);
+ }
}
if (n->type_string() == "_Arg") {
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 1fe612d43d10030675cf307b109e4dcc89cb2d79..c7e8d61d280a33a83c3386d8ef801018634d31ec 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -142,11 +142,22 @@ Status XlaCompileOnDemandOp::Compile(
TF_RETURN_IF_ERROR(ctx->allocate_temp(
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
Notification n;
+ Status status;
ctx->op_device_context()->CopyDeviceTensorToCPU(
&device_tensor, "ConstantArgument",
reinterpret_cast(ctx->device()), &host_tensor,
- [&](Status status) { n.Notify(); });
+ [&](Status s) {
+ status = s;
+ n.Notify();
+ });
n.WaitForNotification();
+ if (!status.ok()) {  // Log and propagate the device-to-CPU copy failure.
+ LOG(ERROR) << "Copying tensor of shape "
+ << device_tensor.shape().DebugString() << " from "
+ << ctx->device()->name() << " to CPU failed with "
+ << status.ToString();
+ return status;
+ }
constant_arguments[i] = host_tensor;
}
}
@@ -189,6 +200,7 @@ Status XlaCompileOnDemandOp::Compile(
std::map variable_args = GetVariables(ctx);
std::vector args;
+
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arguments, variable_args, ctx, &args));
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 7df898ad12a15345f45fc96e0ec3d42b6e51731b..e9770647e7ba96cc1db026d12d5f11f52ce98d35 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -63,7 +63,19 @@ Status XlaCpuDeviceFactory::CreateDevices(
options.device_ordinal = 0;
options.compilation_device_name = DEVICE_CPU_XLA_JIT;
options.use_multiple_streams = false;
- devices->push_back(absl::make_unique(session_options, options));
+ auto device = absl::make_unique(session_options, options);
+
+ // Setting GpuDeviceInfo because eager runtime relies on the device
+ // context in tensorflow_gpu_device_info(). Also,
+ // tensorflow_gpu_device_info() == nullptr is used as an IsCPU test.
+ // We need XlaCpuDevice to be treated not as CPU because it allocates
+ // XlaTensors, not regular Tensors.
+ Status status = device->UseGpuDeviceInfo();
+ if (!status.ok()) {  // Name the CPU JIT device this factory is setting up.
+ errors::AppendToMessage(&status, "while setting up ", DEVICE_CPU_XLA_JIT);
+ return status;
+ }
+ devices->push_back(std::move(device));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 6e6532731e64bd42ee56aa719748988f321e0f17..1f3afe8822d441a5ce37617fe18d7767e9bc72e4 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -79,6 +79,13 @@ XlaDeviceContext::XlaDeviceContext(
}
}
+void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
+ Device* device,
+ Tensor* output_tensor,
+ StatusCallback done) const {
+ done(errors::Unimplemented("XLA->XLA same-device copies not implemented."));
+}
+
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 1e18df197a2dd65590c5181b4dae4481dca36641..e45db989fac720df6c3458c93a6b8dbb0919f930 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -62,6 +62,9 @@ class XlaDeviceContext : public DeviceContext {
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
absl::string_view tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
+ void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
+ Tensor* output_tensor,
+ StatusCallback done) const override;
xla::LocalClient* client() const { return client_; }
se::Stream* stream() const { return stream_.get(); }
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index adf0f994b84d9fbf918a5b2478aa7d106853e038..927f983ba9ef23c8509523f42366c0c89c29db9f 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -203,6 +203,8 @@ class XlaAssignVariableOp : public OpKernel {
.HostMemory("output") \
.TypeConstraint("T"), \
ArgOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name(kArgOp).Device(DEVICE).TypeConstraint("T"), ArgOp); \
\
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
.Device(DEVICE) \
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 944f732b99c0924a08932eda0aedd8c815cc51d0..0191315a66f4d331e54fadc9dc6a073a05fd67ef 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -16,7 +16,10 @@ limitations under the License.
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
+#include
#include "absl/memory/memory.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
@@ -52,8 +55,35 @@ Status XlaGpuDeviceFactory::CreateDevices(
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
return Status::OK();
}
-
- for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) {
+ string allowed_gpus =
+ session_options.config.gpu_options().visible_device_list();
+ std::set gpu_ids;
+ int num_visible_devices = platform.ValueOrDie()->VisibleDeviceCount();
+ if (allowed_gpus.empty()) {
+ for (int i = 0; i < num_visible_devices; ++i) {
+ gpu_ids.insert(i);
+ }
+ } else {
+ // For loop below is copied from gpu/gpu_device.cc. It validates
+ // the visible_device_list and populates gpu_ids set.
+ const std::vector visible_devices =
+ absl::StrSplit(allowed_gpus, ',');
+ for (const string& platform_gpu_id_str : visible_devices) {
+ int32 platform_gpu_id;
+ if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) {
+ return errors::InvalidArgument(
+ "Could not parse entry in 'visible_device_list': '",
+ platform_gpu_id_str, "'. visible_device_list = ", allowed_gpus);
+ }
+ if (platform_gpu_id < 0 || platform_gpu_id >= num_visible_devices) {
+ return errors::InvalidArgument(
+ "'visible_device_list' listed an invalid GPU id '", platform_gpu_id,
+ "' but visible device count is ", num_visible_devices);
+ }
+ gpu_ids.insert(platform_gpu_id);
+ }
+ }
+ for (int i : gpu_ids) {
XlaDevice::Options options;
options.platform = platform.ValueOrDie();
options.device_name_prefix = name_prefix;
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 437db019a0eabe66417725148d8b121842e90479..554227f09de0ab4d9e07f199b957657f3121ff06 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -199,19 +199,17 @@ class XlaTensorBuffer : public TensorBuffer {
public:
XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size,
Allocator* allocator)
- : expected_size_(expected_size),
+ : TensorBuffer(const_cast(ptr)),
+ expected_size_(expected_size),
actual_size_(actual_size),
- allocator_(allocator) {
- data_ = const_cast(ptr);
- }
+ allocator_(allocator) {}
~XlaTensorBuffer() override {
- if (data_) {
- allocator_->DeallocateRaw(data_);
+ if (data()) {
+ allocator_->DeallocateRaw(data());
}
}
- void* data() const override { return data_; }
size_t size() const override { return expected_size_; }
TensorBuffer* root_buffer() override { return this; }
@@ -231,7 +229,6 @@ class XlaTensorBuffer : public TensorBuffer {
}
private:
- void* data_;
size_t expected_size_;
size_t actual_size_;
Allocator* allocator_;
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index bc3d60b90e58b4018f1c52b09941dedba7ef348a..093b61629cd0b04d5d8488139b8d7262b739f86d 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -408,13 +408,6 @@ tf_xla_py_test(
name = "eager_test",
size = "large",
srcs = ["eager_test.py"],
- disabled_backends = [
- # TODO(b/78199195) Support XLA CPU devices in eager runtime
- "cpu",
- "cpu_ondemand",
- # TODO(b/78468222) Enable GPU backend
- "gpu",
- ],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py
index 174bfa9efbcd7dcb4f895237eb01c17bc4a3a6b4..90146e6b27ca31304a2549ec247412341efe390c 100644
--- a/tensorflow/compiler/tests/depthwise_conv_op_test.py
+++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py
@@ -350,8 +350,13 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
self._CompareBackpropInput(input_size, filter_size, output_size, stride,
padding)
- def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
- stride, padding):
+ def _CompareBackpropFilter(self,
+ input_sizes,
+ filter_sizes,
+ output_sizes,
+ stride,
+ padding,
+ data_format="NHWC"):
x0 = np.random.rand(*input_sizes).astype(np.float32)
x2 = np.random.rand(*output_sizes).astype(np.float32)
@@ -360,13 +365,30 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
t0 = array_ops.placeholder(np.float32, shape=input_sizes)
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
t2 = array_ops.placeholder(np.float32, shape=output_sizes)
+ native_t0 = t0
+ native_t2 = t2
+ strides = [1, stride, stride, 1]
+
if use_xla:
+ if data_format == "NCHW":
+ # Transpose from NHWC input to NCHW
+ # Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
+ native_t0 = array_ops.transpose(t0, [0, 3, 1, 2])
+ native_t2 = array_ops.transpose(t2, [0, 3, 1, 2])
+ strides = [1, 1, stride, stride]
with self.test_scope():
backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
- t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
+ native_t0,
+ t1,
+ native_t2,
+ strides=strides,
+ padding=padding,
+ data_format=data_format)
else:
+ # For CPU, the format NCHW is not supported. Therefore we always use
+ # NHWC here.
backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
- t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
+ native_t0, t1, native_t2, strides=strides, padding=padding)
ret = backprop.eval({t0: x0, t2: x2})
self.assertShapeEqual(ret, backprop)
return ret
@@ -379,11 +401,24 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
+ input_size, "*", filter_size, "producing output", output_size,
+ "stride:", stride, "padding:", padding)
self._CompareBackpropFilter(input_size, filter_size, output_size,
stride, padding)
+ def testDepthwiseConv2DFilterGradFormatNCHWCompare(self):
+ for index, (input_size, filter_size, output_size, stride,
+ padding) in enumerate(ConfigsToTest()):
+ print("Testing DepthwiseConv2DFilterGradFormatNCHWCompare,", index,
+ "th config:", input_size, "*", filter_size, "producing output",
+ output_size, "stride:", stride, "padding:", padding)
+ self._CompareBackpropFilter(
+ input_size,
+ filter_size,
+ output_size,
+ stride,
+ padding,
+ data_format="NCHW")
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 4cf88fc523735cc2d22e085afb83790c7ebb48e4..28274ff799de2c85e1e80512cadbe0206cb640a4 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -319,7 +319,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
session.run(output)
self.assertRegexpMatches(
invalid_arg_error.exception.message,
- (r'^start_indices must be a vector with length equal to input rank, '
+ (r'start_indices must be a vector with length equal to input rank, '
r'but input rank is 3 and start_indices has shape \[2\].*'))
def testDynamicSliceWithIncorrectSizeIndicesShape(self):
@@ -332,7 +332,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
session.run(output)
self.assertRegexpMatches(
invalid_arg_error.exception.message,
- (r'^size_indices must be a vector with length equal to input rank, '
+ (r'size_indices must be a vector with length equal to input rank, '
r'but input rank is 3 and size_indices has shape \[2\].*'))
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 25a84fb1b6609106213231db1ca1ce54da8bd960..5a0d9b9af9d55a8dee809d3cf909bce39c3b8b6c 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -445,14 +445,9 @@ cc_library(
],
deps = [
"//tensorflow/compiler/jit:flags",
- "//tensorflow/compiler/xla:parse_flags_from_env",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
+ "//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
- "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc
index 1de85004a51bea464f8f0166511402e5dd85ac14..64fdbbebc65bff4ed0b965fcdd534cc9696472b6 100644
--- a/tensorflow/compiler/tf2xla/dump_graph.cc
+++ b/tensorflow/compiler/tf2xla/dump_graph.cc
@@ -18,86 +18,26 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/flags.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace dump_graph {
-namespace {
-
-struct NameCounts {
- mutex counts_mutex;
- std::unordered_map counts;
-};
-
-string MakeUniqueFilename(string name) {
- static NameCounts& instance = *new NameCounts;
-
- // Remove illegal characters from `name`.
- for (int i = 0; i < name.size(); ++i) {
- char ch = name[i];
- if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') {
- name[i] = '_';
- }
- }
-
- int count;
- {
- mutex_lock lock(instance.counts_mutex);
- count = instance.counts[name]++;
- }
-
- string filename = name;
- if (count > 0) {
- absl::StrAppend(&filename, "_", count);
- }
- absl::StrAppend(&filename, ".pbtxt");
- return filename;
-}
-
-string WriteTextProtoToUniqueFile(
- Env* env, const string& name, const char* proto_type,
- const ::tensorflow::protobuf::Message& proto) {
- const string& dirname = GetDumpGraphFlags()->tf_dump_graph_prefix;
- Status status = env->RecursivelyCreateDir(dirname);
- if (!status.ok()) {
- LOG(WARNING) << "Failed to create " << dirname << " for dumping "
- << proto_type << ": " << status;
- return "(unavailable)";
- }
- string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name));
- status = WriteTextProto(Env::Default(), filepath, proto);
- if (!status.ok()) {
- LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath
- << " : " << status;
- return "(unavailable)";
- }
- LOG(INFO) << "Dumped " << proto_type << " to " << filepath;
- return filepath;
-}
-
-} // anonymous namespace
-
string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) {
- return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef",
- graph_def);
+ return tensorflow::DumpGraphDefToFile(
+ name, graph_def, GetDumpGraphFlags()->tf_dump_graph_prefix);
}
string DumpGraphToFile(const string& name, Graph const& graph,
const FunctionLibraryDefinition* flib_def) {
- GraphDef graph_def;
- graph.ToGraphDef(&graph_def);
- if (flib_def) {
- *graph_def.mutable_library() = flib_def->ToProto();
- }
- return DumpGraphDefToFile(name, graph_def);
+ return tensorflow::DumpGraphToFile(name, graph, flib_def,
+ GetDumpGraphFlags()->tf_dump_graph_prefix);
}
string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) {
- return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef);
+ return tensorflow::DumpFunctionDefToFile(
+ name, fdef, GetDumpGraphFlags()->tf_dump_graph_prefix);
}
} // namespace dump_graph
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index c693e42d26712d55852f45c806215fc1f1b9a030..7ae96e1d484900e28e8c23c3bb2232401144ad82 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -640,7 +640,8 @@ Status Conditional::ExtractBodies(Graph* graph) {
Status Conditional::BuildIfNode(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(2) << "Build cond function for " << name();
- NodeDefBuilder builder(name(), "If", library);
+ NodeDebugInfo debug_info((*merges_.begin())->def());
+ NodeDefBuilder builder(name(), "If", library, &debug_info);
const string branch_name[] = {"else_branch", "then_branch"};
for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
int branch_index = static_cast(branch);
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index d85b4f5ae0cb9c7d2476158a5830f921742ae980..a18a4e92d62787051f6ab92e72ee8bf0d1060dca 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -1,16 +1,11 @@
+load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")
+
licenses(["notice"]) # Apache 2.0
package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)
-load("//tensorflow:tensorflow.bzl", "tf_copts")
-load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
-load(
- "//third_party/mkl:build_defs.bzl",
- "if_mkl",
-)
-
tf_kernel_library(
name = "xla_ops",
srcs = [
@@ -121,15 +116,10 @@ tf_kernel_library(
":while_op",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
- "//tensorflow/compiler/tf2xla/lib:batch_dot",
"//tensorflow/compiler/tf2xla/lib:broadcast",
- "//tensorflow/compiler/tf2xla/lib:cholesky",
- "//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random",
"//tensorflow/compiler/tf2xla/lib:scatter",
- "//tensorflow/compiler/tf2xla/lib:triangular_solve",
"//tensorflow/compiler/tf2xla/lib:util",
- "//tensorflow/compiler/tf2xla/lib:while_loop",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal",
@@ -142,12 +132,16 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:cholesky",
"//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math",
- "//tensorflow/compiler/xla/client/lib:numeric",
+ "//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:pooling",
"//tensorflow/compiler/xla/client/lib:prng",
+ "//tensorflow/compiler/xla/client/lib:qr",
"//tensorflow/compiler/xla/client/lib:sorting",
+ "//tensorflow/compiler/xla/client/lib:triangular_solve",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",
"//tensorflow/core:lib",
@@ -196,7 +190,6 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
- "//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -216,7 +209,6 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
- "//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/core:framework",
"//tensorflow/core/kernels:bounds_check",
"//tensorflow/core/kernels:conv_ops",
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
index 4cfe946b2e6146f034867c06e996ffae42b90705..1b254e328a8c71bd81a0ec700e2af1d81a5fa67a 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
namespace tensorflow {
namespace {
@@ -28,9 +30,11 @@ class BatchMatMulOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- auto result = BatchDot(ctx->Input(0), ctx->Input(1),
- /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_,
- /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_);
+ auto result =
+ xla::BatchDot(MaybeTransposeInMinorDims(
+ MaybeConjugate(ctx->Input(0), adj_x_), adj_x_),
+ MaybeTransposeInMinorDims(
+ MaybeConjugate(ctx->Input(1), adj_y_), adj_y_));
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
index 9fcbc86adc0967cbb7fb73da8bdabc58b60953da..0ed3044efa5b1060d2b0ad2d5563b0e02ebf66ec 100644
--- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/cholesky.h"
namespace tensorflow {
namespace {
@@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel {
public:
explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- ctx->SetOutput(0, Cholesky(ctx->Input(0)));
+ ctx->SetOutput(0, xla::Cholesky(ctx->Input(0)));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
index 641fefafb357f6ad10483c454600f3dadd4f8cb7..4124b258c7788e3850f07cbf4d53930784c635fd 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -392,23 +392,31 @@ xla::StatusOr MakeXlaBackpropFilterConvOp(
builder->GetShape(activations));
TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
builder->GetShape(gradients));
+ xla::XlaOp filter_backprop;
+
+ xla::Shape input_shape = activations_shape;
+ xla::Shape output_shape = out_backprop_shape;
+
+ TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape));
+
const xla::Shape expanded_filter_shape =
attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
: filter_shape;
-
// Reuse dimension computation logic from conv_grad_ops.cc.
ConvBackpropDimensions dims;
- TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
- type_string, attrs.num_spatial_dims, activations_shape,
- expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
- attrs.padding, attrs.data_format, &dims));
-
// The filter gradients are computed by a convolution of the input
// activations and the output gradients, with some appropriate padding.
// See the comment at the top of conv_grad_ops.h for details.
-
xla::ConvolutionDimensionNumbers dnums;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, activations_shape,
+ expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
+ attrs.padding, attrs.data_format, &dims));
+
// The activations (inputs) form the LHS of the convolution.
// Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
// For the gradient computation, we flip the roles of the batch and
@@ -420,29 +428,99 @@ xla::StatusOr MakeXlaBackpropFilterConvOp(
int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
- // Swap n_dim and c_dim in the activations.
- dnums.set_input_batch_dimension(c_dim);
- dnums.set_input_feature_dimension(n_dim);
+ int64 total_spatial_size = 1;
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ total_spatial_size *= dims.input_size(i);
+ }
- // The gradients become the RHS of the convolution.
- // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
- // where the batch becomes the input feature for the convolution.
- dnums.set_kernel_input_feature_dimension(n_dim);
- dnums.set_kernel_output_feature_dimension(c_dim);
+ // We use this approach only for depthwise convolutions where feature counts
+ // are large but space dimensions are small. The conversion logic below
+ // assumes that the data format is NHWC, so we also check that here.
+ bool should_perform_depthwise_conv =
+ attrs.data_format == FORMAT_NHWC &&
+ (total_spatial_size < dims.in_depth) &&
+ filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise;
+
+ int64 num_spatial_dims =
+ attrs.num_spatial_dims + (should_perform_depthwise_conv ? 1 : 0);
+
+ std::vector> padding(num_spatial_dims);
+ std::vector rhs_dilation(num_spatial_dims);
+ std::vector window_strides(num_spatial_dims);
+ std::vector