diff --git a/.github/ISSUE_TEMPLATE/40-tflite-op-request.md b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md
new file mode 100644
index 0000000000000000000000000000000000000000..7b391279e479ade4ed5327728f19be8752e11507
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md
@@ -0,0 +1,24 @@
+---
+name: TensorFlow Lite Op Request
+about: Use this template for reporting ops you are using or missing.
+
+---
+
+
+**System information**
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- TensorFlow installed from (source or binary):
+- TensorFlow version (or github SHA if from source):
+
+
+**Provide the text output from tflite_convert**
+
+```
+# Copy and paste here
+```
+
+Also, please include a link to a GraphDef or the model if possible.
+
+**Any other info / logs**
+
+Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
diff --git a/.gitignore b/.gitignore
index 57d84228cfd037325716b5faa56c17f7424fe713..90324058600bee46af56e49028977971848a80de 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,7 +24,7 @@ Pods
Podfile.lock
*.pbxproj
*.xcworkspacedata
-/tensorflow/lite/downloads/**
+/tensorflow/lite/tools/make/downloads/**
/tensorflow/lite/gen/**
/tensorflow/lite/examples/ios/simple/data/*.txt
/tensorflow/lite/examples/ios/simple/data/*.tflite
diff --git a/CODEOWNERS b/CODEOWNERS
index 54a61a4d72c40d297d90d53e223f64f813d9167d..cb3fa2312405ce44d5dfc30ea4164740f436e07e 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,7 +1,7 @@
# Where component owners are known, add them here.
/tenosrflow/core/debug @caisq
-/tensorflow/core/nccl/ @azaks @csigg
+/tensorflow/core/nccl/ @azaks2 @chsigg
/tensorflow/core/platform/windows/ @mrry
/tensorflow/core/platform/s3 @yongtang
/tensorflow/go @asimshankar
@@ -51,13 +51,13 @@
/tensorflow/contrib/pi_examples/ @maciekcc
/tensorflow/contrib/quantization/ @petewarden
/tensorflow/contrib/rnn/ @ebrevdo @scottzhu
-/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl
+/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie
/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang
/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh
/tensorflow/contrib/slim/ @sguada @thenbasilmanran
/tensorflow/contrib/stateless/ @girving @alextp
/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank
-/tensorflow/contrib/tensorrt/ @aaroey
+/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2
# NEED OWNER: /tensorflow/contrib/testing/
/tensorflow/contrib/timeseries/ @allenlavoie
/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj
diff --git a/README.md b/README.md
index 8af5370befbb090966a8b3af54d80c84a969aaa5..044174947a094d43a51f7140dd40ec0f17801d40 100644
--- a/README.md
+++ b/README.md
@@ -9,12 +9,14 @@
|-----------------|
| [](https://www.tensorflow.org/api_docs/) |
-**TensorFlow** is an open source software library for numerical computation using
-data flow graphs. The graph nodes represent mathematical operations, while
+**TensorFlow** is an open source software library for numerical computation
+using data flow graphs. The graph nodes represent mathematical operations, while
the graph edges represent the multidimensional data arrays (tensors) that flow
-between them. This flexible architecture enables you to deploy computation to one
-or more CPUs or GPUs in a desktop, server, or mobile device without rewriting
-code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), a data visualization toolkit.
+between them. This flexible architecture enables you to deploy computation to
+one or more CPUs or GPUs in a desktop, server, or mobile device without
+rewriting code. TensorFlow also includes
+[TensorBoard](https://github.com/tensorflow/tensorboard), a data visualization
+toolkit.
TensorFlow was originally developed by researchers and engineers
working on the Google Brain team within Google's Machine Intelligence Research
@@ -111,22 +113,24 @@ The TensorFlow project strives to abide by generally accepted best practices in
Build Type | Status | Artifacts
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**IBM s390x** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
-**IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA
+**IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA
**IBM ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**IBM ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)
[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)
[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)
[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl)
## For more information
-* [TensorFlow Website](https://www.tensorflow.org)
-* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
-* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
-* [TensorFlow Twitter](https://twitter.com/tensorflow)
-* [TensorFlow Blog](https://medium.com/tensorflow)
-* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
-* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
-* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
-* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
+
+* [TensorFlow Website](https://www.tensorflow.org)
+* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
+* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
+* [TensorFlow Twitter](https://twitter.com/tensorflow)
+* [TensorFlow Blog](https://medium.com/tensorflow)
+* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
+* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
+* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
+* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
+* [TensorFlow Visualization Toolkit](https://github.com/tensorflow/tensorboard)
Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate.
diff --git a/WORKSPACE b/WORKSPACE
index 0c7bc085b512b084b9470abe17326d7c119aa327..7cc08e0164a202581ad7ebbe107a9e19410e70e4 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,5 +1,7 @@
workspace(name = "org_tensorflow")
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
http_archive(
name = "io_bazel_rules_closure",
sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae",
@@ -57,9 +59,9 @@ android_workspace()
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_workspace()
-new_http_archive(
+http_archive(
name = "inception_v1",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip",
@@ -67,9 +69,9 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "mobile_ssd",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
@@ -77,9 +79,9 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "mobile_multibox",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",
@@ -87,9 +89,9 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "stylize",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
@@ -97,9 +99,9 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "speech_commands",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
diff --git a/configure.py b/configure.py
index 234561d94a46f57c4de5ca487360e2d5a3dfdb2f..6c905a0be3d685b5921dfbc5bddfbe6471a82625 100644
--- a/configure.py
+++ b/configure.py
@@ -238,6 +238,13 @@ def setup_python(environ_cp):
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
+ # If the chosen python_lib_path is from a path specified in the PYTHONPATH
+ # variable, we need to tell bazel to include PYTHONPATH.
+ if environ_cp.get('PYTHONPATH'):
+ python_paths = environ_cp.get('PYTHONPATH').split(':')
+ if python_lib_path in python_paths:
+ write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH'))
+
# Write tools/python_bin_path.sh
with open(
os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
@@ -445,11 +452,12 @@ def convert_version_to_int(version):
return int(version_str)
-def check_bazel_version(min_version):
- """Check installed bazel version is at least min_version.
+def check_bazel_version(min_version, max_version):
+ """Check installed bazel version is between min_version and max_version.
Args:
min_version: string for minimum bazel version.
+ max_version: string for maximum bazel version.
Returns:
The bazel version detected.
@@ -467,6 +475,7 @@ def check_bazel_version(min_version):
min_version_int = convert_version_to_int(min_version)
curr_version_int = convert_version_to_int(curr_version)
+ max_version_int = convert_version_to_int(max_version)
# Check if current bazel version can be detected properly.
if not curr_version_int:
@@ -480,6 +489,10 @@ def check_bazel_version(min_version):
print('Please upgrade your bazel installation to version %s or higher to '
'build TensorFlow!' % min_version)
sys.exit(0)
+ if curr_version_int > max_version_int:
+ print('Please downgrade your bazel installation to version %s or lower to '
+ 'build TensorFlow!' % max_version)
+ sys.exit(0)
return curr_version
@@ -859,7 +872,7 @@ def set_tf_cuda_version(environ_cp):
cuda_toolkit_paths_full = [
os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths
]
- if any([os.path.exists(x) for x in cuda_toolkit_paths_full]):
+ if any(os.path.exists(x) for x in cuda_toolkit_paths_full):
break
# Reset and retry
@@ -1552,7 +1565,7 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
- check_bazel_version('0.15.0')
+ check_bazel_version('0.15.0', '0.20.0')
reset_tf_configure_bazelrc()
# Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later
@@ -1694,6 +1707,7 @@ def main():
config_info_line('nohdfs', 'Disable HDFS support.')
config_info_line('noignite', 'Disable Apacha Ignite support.')
config_info_line('nokafka', 'Disable Apache Kafka support.')
+ config_info_line('nonccl', 'Disable NVIDIA NCCL support.')
if __name__ == '__main__':
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 859dc3b8d77be66e0f51e15d86188399273af23f..fd4b94202aad24a82abef8abd16431f61a8326f0 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -43,6 +43,11 @@ TENSORFLOW_API_INIT_FILES_V2 = (
TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
)
+# @unused
+TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT = (
+ TENSORFLOW_API_INIT_FILES_V1 + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
+)
+
# Config setting used when building for products
# which requires restricted licenses to be avoided.
config_setting(
@@ -213,31 +218,37 @@ config_setting(
#
config_setting(
name = "no_aws_support",
- define_values = {"no_aws_support": "false"},
+ define_values = {"no_aws_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "no_gcp_support",
- define_values = {"no_gcp_support": "false"},
+ define_values = {"no_gcp_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "no_hdfs_support",
- define_values = {"no_hdfs_support": "false"},
+ define_values = {"no_hdfs_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "no_ignite_support",
- define_values = {"no_ignite_support": "false"},
+ define_values = {"no_ignite_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "no_kafka_support",
- define_values = {"no_kafka_support": "false"},
+ define_values = {"no_kafka_support": "true"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "no_nccl_support",
+ define_values = {"no_nccl_support": "true"},
visibility = ["//visibility:public"],
)
@@ -350,7 +361,7 @@ package_group(
"-//third_party/tensorflow/python/estimator",
"//learning/meta_rank/...",
"//tensorflow/...",
- "//tensorflow_estimator/...",
+ "//tensorflow_estimator/contrib/...",
"//tensorflow_fold/llgtm/...",
"//tensorflow_text/...",
"//third_party/py/tensor2tensor/...",
@@ -554,18 +565,24 @@ genrule(
}),
outs = ["__init__.py"],
cmd = select({
- "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
- "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
+ "api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)",
+ "//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)",
}),
)
gen_api_init_files(
name = "tf_python_api_gen_v1",
- srcs = ["api_template_v1.__init__.py"],
+ srcs = [
+ "api_template_v1.__init__.py",
+ "compat_template_v1.__init__.py",
+ ],
api_version = 1,
+ compat_api_versions = [1],
+ compat_init_templates = ["compat_template_v1.__init__.py"],
output_dir = "_api/v1/",
- output_files = TENSORFLOW_API_INIT_FILES_V1,
+ output_files = TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT,
output_package = "tensorflow._api.v1",
+ root_file_name = "v1.py",
root_init_template = "api_template_v1.__init__.py",
)
@@ -581,6 +598,7 @@ gen_api_init_files(
output_dir = "_api/v2/",
output_files = TENSORFLOW_API_INIT_FILES_V2,
output_package = "tensorflow._api.v2",
+ root_file_name = "v2.py",
root_init_template = "api_template.__init__.py",
)
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 0d49756838505289a960a6cabeb7cab02fad995b..d81cf067eb07e88e2b8a86cf5643674235eb3f3b 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -21,8 +21,6 @@ from __future__ import print_function as _print_function
import os as _os
# pylint: disable=g-bad-import-order
-from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
-
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
@@ -30,16 +28,16 @@ _component_api_helper.package_hook(
# API IMPORTS PLACEHOLDER
-from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
-
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
-_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable
+# We're using bitwise, but there's nothing special about that.
+_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable
if _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
-# Calls to enable and disable features.
-enable_eager_execution() # pylint: disable=undefined-variable
+# Enable TF2 behaviors
+from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top
+_compat.enable_v2_behavior()
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index b8db1b2144978e97bd32f62e643c2c4a7fcf1654..25df970ecab0757f23465ab19e7f45de0c759458 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -60,6 +60,7 @@ tf_cuda_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:op_gen_lib",
+ "//tensorflow/core/distributed_runtime:server_lib",
],
}),
)
@@ -120,7 +121,8 @@ tf_cuda_library(
":c_api",
":c_api_internal",
"//tensorflow/c/eager:c_api",
- "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
+ "//tensorflow/c/eager:c_api_internal",
+ "//tensorflow/compiler/jit:flags",
"//tensorflow/contrib/tpu:all_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@@ -173,6 +175,60 @@ tf_cuda_library(
],
)
+tf_cuda_library(
+ name = "env",
+ srcs = [
+ "env.cc",
+ ],
+ hdrs = [
+ "env.h",
+ ],
+ copts = tf_copts(),
+ visibility = ["//visibility:public"],
+ deps = select({
+ "//tensorflow:android": [
+ ":c_api",
+ ":tf_status_helper",
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ "//tensorflow/core:platform_env",
+ "//tensorflow/core:lib",
+ ],
+ "//conditions:default": [
+ ":c_api",
+ ":tf_status_helper",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:platform_env",
+ "//tensorflow/core:lib",
+ ],
+ }) + [":c_api_internal"],
+)
+
+tf_cuda_library(
+ name = "kernels",
+ srcs = [
+ "kernels.cc",
+ ],
+ hdrs = [
+ "kernels.h",
+ ],
+ copts = tf_copts(),
+ visibility = ["//visibility:public"],
+ deps = select({
+ "//tensorflow:android": [
+ ":c_api",
+ ":c_api_internal",
+ ":tf_status_helper",
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ ":c_api",
+ ":c_api_internal",
+ ":tf_status_helper",
+ "//tensorflow/core:framework",
+ ],
+ }),
+)
+
# -----------------------------------------------------------------------------
# Tests
@@ -208,7 +264,10 @@ tf_cuda_cc_test(
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
- tags = ["noasan"],
+ tags = [
+ "no_oss", # http://b/119522529
+ "noasan",
+ ],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
@@ -237,7 +296,7 @@ tf_cuda_cc_test(
tf_cc_test(
name = "c_api_experimental_test",
- size = "small",
+ size = "medium",
srcs = ["c_api_experimental_test.cc"],
data = ["testdata/tf_record"],
linkopts = select({
@@ -248,8 +307,11 @@ tf_cc_test(
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
+ ":c_api",
":c_api_experimental",
":c_test_util",
+ "//tensorflow/c/eager:c_api",
+ "//tensorflow/c/eager:c_api_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
@@ -300,6 +362,51 @@ tf_kernel_library(
alwayslink = 1,
)
+tf_cuda_cc_test(
+ name = "env_test",
+ size = "small",
+ srcs = ["env_test.cc"],
+ linkopts = select({
+ "//tensorflow:darwin": ["-headerpad_max_install_names"],
+ "//conditions:default": [],
+ }),
+ tags = ["noasan"],
+ # We must ensure that the dependencies can be dynamically linked since
+ # the shared library must be able to use core:framework.
+ # linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":c_api",
+ ":env",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "kernels_test",
+ size = "small",
+ srcs = ["kernels_test.cc"],
+ linkopts = select({
+ "//tensorflow:darwin": ["-headerpad_max_install_names"],
+ "//conditions:default": [],
+ }),
+ tags = ["noasan"],
+ # We must ensure that the dependencies can be dynamically linked since
+ # the shared library must be able to use core:framework.
+ # linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":c_api",
+ ":kernels",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:proto_text",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Python API target
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index f13e8777dff164bcd8eedf46310ae846abd0c804..94d18eb8b04e3534be547aca5cfbb32da40ffbf6 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -136,16 +136,22 @@ const char* TF_Message(const TF_Status* s) {
namespace {
class TF_ManagedBuffer : public TensorBuffer {
public:
- void* data_;
- size_t len_;
- void (*deallocator_)(void* data, size_t len, void* arg);
- void* deallocator_arg_;
+ TF_ManagedBuffer(void* data, size_t len,
+ void (*deallocator)(void* data, size_t len, void* arg),
+ void* deallocator_arg)
+ : TensorBuffer(data),
+ len_(len),
+ deallocator_(deallocator),
+ deallocator_arg_(deallocator_arg) {}
+
+ const size_t len_;
+ void (*const deallocator_)(void* data, size_t len, void* arg);
+ void* const deallocator_arg_;
~TF_ManagedBuffer() override {
- (*deallocator_)(data_, len_, deallocator_arg_);
+ (*deallocator_)(data(), len_, deallocator_arg_);
}
- void* data() const override { return data_; }
size_t size() const override { return len_; }
TensorBuffer* root_buffer() override { return this; }
void FillAllocationDescription(AllocationDescription* proto) const override {
@@ -199,8 +205,7 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
dimvec[i] = static_cast(dims[i]);
}
- TF_ManagedBuffer* buf = new TF_ManagedBuffer;
- buf->len_ = len;
+ TF_ManagedBuffer* buf = nullptr;
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
tensorflow::DataTypeCanUseMemcpy(static_cast(dtype)) &&
reinterpret_cast(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) !=
@@ -212,17 +217,15 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
//
// Other types have the same representation, so copy only if it is safe to
// do so.
- buf->data_ = allocate_tensor("TF_NewTensor", len);
- std::memcpy(buf->data_, data, len);
- buf->deallocator_ = deallocate_buffer;
- buf->deallocator_arg_ = nullptr;
+ buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len,
+ deallocate_buffer, nullptr);
+ std::memcpy(buf->data(), data, len);
// Free the original buffer.
deallocator(data, len, deallocator_arg);
} else {
- buf->data_ = data;
- buf->deallocator_ = deallocator;
- buf->deallocator_arg_ = deallocator_arg;
+ buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
}
+
TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf};
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) {
@@ -477,9 +480,9 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
CHECK_EQ(nelems, 0);
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
- return TF_NewTensor(dtype, reinterpret_cast(dims.data()),
- shape.dims(), reinterpret_cast(&empty), 0,
- [](void*, size_t, void*) {}, nullptr);
+ return TF_NewTensor(
+ dtype, reinterpret_cast(dims.data()), shape.dims(),
+ reinterpret_cast(&empty), 0, [](void*, size_t, void*) {}, nullptr);
}
// Non-static for testing.
@@ -1592,18 +1595,20 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
break; \
}
- LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0;
- for (int i = 0; i < attr->list().s_size();
- ++i) { metadata.total_size += attr->list().s(i).size(); });
+ LIST_CASE(
+ s, TF_ATTR_STRING, metadata.total_size = 0;
+ for (int i = 0; i < attr->list().s_size();
+ ++i) { metadata.total_size += attr->list().s(i).size(); });
LIST_CASE(i, TF_ATTR_INT);
LIST_CASE(f, TF_ATTR_FLOAT);
LIST_CASE(b, TF_ATTR_BOOL);
LIST_CASE(type, TF_ATTR_TYPE);
- LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0;
- for (int i = 0; i < attr->list().shape_size(); ++i) {
- const auto& s = attr->list().shape(i);
- metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
- });
+ LIST_CASE(
+ shape, TF_ATTR_SHAPE, metadata.total_size = 0;
+ for (int i = 0; i < attr->list().shape_size(); ++i) {
+ const auto& s = attr->list().shape(i);
+ metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
+ });
LIST_CASE(tensor, TF_ATTR_TENSOR);
LIST_CASE(tensor, TF_ATTR_FUNC);
#undef LIST_CASE
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 3d56268110edbe96616201d15a69cc8c84d3115a..c7abba85521fccec07983cd5ab4f94a8368d6181 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -91,7 +91,7 @@ extern "C" {
// --------------------------------------------------------------------------
// TF_Version returns a string describing version information of the
// TensorFlow library. TensorFlow using semantic versioning.
-TF_CAPI_EXPORT extern const char* TF_Version();
+TF_CAPI_EXPORT extern const char* TF_Version(void);
// --------------------------------------------------------------------------
// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
@@ -157,7 +157,7 @@ typedef enum TF_Code {
typedef struct TF_Status TF_Status;
// Return a new status object.
-TF_CAPI_EXPORT extern TF_Status* TF_NewStatus();
+TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void);
// Delete a previously created status object.
TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*);
@@ -196,7 +196,7 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto,
size_t proto_len);
// Useful for passing *out* a protobuf.
-TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer();
+TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void);
TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*);
@@ -305,7 +305,7 @@ TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len);
typedef struct TF_SessionOptions TF_SessionOptions;
// Return a new options object.
-TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions();
+TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void);
// Set the target in TF_SessionOptions.options.
// target can be empty, a single entry, or a comma separated list of entries.
@@ -338,7 +338,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*);
typedef struct TF_Graph TF_Graph;
// Return a new graph object.
-TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph();
+TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void);
// Destroy an options object. Graph will be deleted once no more
// TFSession's are referencing it.
@@ -890,7 +890,8 @@ TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph,
// TF_GraphImportGraphDef.
typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions;
-TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions();
+TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(
+ void);
TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions(
TF_ImportGraphDefOptions* opts);
@@ -1611,7 +1612,7 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);
//
// The data in the buffer will be the serialized OpList proto for ops registered
// in this address space.
-TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList();
+TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void);
// TF_ApiDefMap encapsulates a collection of API definitions for an operation.
//
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index fabe2fa0f60bc8baafa7f83802da74bb7ab93c6d..38e29aa74a90f4e85d1369b6928a5a58c531b2da 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -15,13 +15,18 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
+#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@@ -51,8 +56,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
// These XLA flags are needed to trigger XLA properly from C (more generally
// non-Python) clients. If this API is called again with `enable` set to
// false, it is safe to keep these flag values as is.
- tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
- tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+ tensorflow::MarkForCompilationPassFlags* flags =
+ tensorflow::GetMarkForCompilationPassFlags();
flags->tf_xla_cpu_global_jit = true;
flags->tf_xla_min_cluster_size = 1;
} else {
@@ -71,8 +76,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
// These XLA flags are needed to trigger XLA properly from C (more generally
// non-Python) clients. If this API is called again with `enable` set to
// false, it is safe to keep these flag values as is.
- tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
- tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+ tensorflow::MarkForCompilationPassFlags* flags =
+ tensorflow::GetMarkForCompilationPassFlags();
flags->tf_xla_cpu_global_jit = true;
flags->tf_xla_min_cluster_size = 1;
} else {
@@ -6525,7 +6530,7 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset/cycle_length"
+ name: "ExperimentalParallelInterleaveDataset/cycle_length"
op: "Const"
attr {
key: "dtype"
@@ -6546,7 +6551,7 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset/block_length"
+ name: "ExperimentalParallelInterleaveDataset/block_length"
op: "Const"
attr {
key: "dtype"
@@ -6567,7 +6572,7 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset/sloppy"
+ name: "ExperimentalParallelInterleaveDataset/sloppy"
op: "Const"
attr {
key: "dtype"
@@ -6588,7 +6593,7 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset/buffer_output_elements"
+ name: "ExperimentalParallelInterleaveDataset/buffer_output_elements"
op: "Const"
attr {
key: "dtype"
@@ -6609,7 +6614,7 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset/prefetch_input_elements"
+ name: "ExperimentalParallelInterleaveDataset/prefetch_input_elements"
op: "Const"
attr {
key: "dtype"
@@ -6630,14 +6635,14 @@ library {
}
}
node_def {
- name: "ParallelInterleaveDataset"
- op: "ParallelInterleaveDataset"
+ name: "ExperimentalParallelInterleaveDataset"
+ op: "ExperimentalParallelInterleaveDataset"
input: "RepeatDataset:handle:0"
- input: "ParallelInterleaveDataset/cycle_length:output:0"
- input: "ParallelInterleaveDataset/block_length:output:0"
- input: "ParallelInterleaveDataset/sloppy:output:0"
- input: "ParallelInterleaveDataset/buffer_output_elements:output:0"
- input: "ParallelInterleaveDataset/prefetch_input_elements:output:0"
+ input: "ExperimentalParallelInterleaveDataset/cycle_length:output:0"
+ input: "ExperimentalParallelInterleaveDataset/block_length:output:0"
+ input: "ExperimentalParallelInterleaveDataset/sloppy:output:0"
+ input: "ExperimentalParallelInterleaveDataset/buffer_output_elements:output:0"
+ input: "ExperimentalParallelInterleaveDataset/prefetch_input_elements:output:0"
attr {
key: "Targuments"
value {
@@ -6737,7 +6742,7 @@ library {
node_def {
name: "ShuffleDataset_2"
op: "ShuffleDataset"
- input: "ParallelInterleaveDataset:handle:0"
+ input: "ExperimentalParallelInterleaveDataset:handle:0"
input: "ShuffleDataset_2/buffer_size_1:output:0"
input: "ShuffleDataset_2/seed_2:output:0"
input: "ShuffleDataset_2/seed2_2:output:0"
@@ -8739,14 +8744,65 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
TF_DeleteStatus(status);
}
-TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
- const char* errMsg) {
+struct TFE_ExecuteOpNotification {  // Per-call state for one op launched via TFE_ExecuteOpInNewThread.
+ TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}  // Allocates the status; TF_DeleteStatus is installed as its deleter.
+ tensorflow::Notification n;  // Notified by the worker thread right after its TFE_Execute call returns.
+ std::unique_ptr thread;  // Holds the thread returned by StartThread in TFE_ExecuteOpInNewThread.
+ std::unique_ptr status;  // Receives the result of the TFE_Execute call made on the worker thread.
+};
+
+TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,  // Runs TFE_Execute(op, retvals, num_retvals) on a newly started thread; returns a heap-allocated notification.
+ TFE_TensorHandle** retvals,
+ int* num_retvals,
+ TF_Status* status) {  // NOTE(review): `status` is never written in this function; execution errors are stored in n->status.
+ TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;  // Released by TFE_ExecuteOpNotificationWaitAndDelete.
+
+ n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
+ tensorflow::ThreadOptions(), "ExecuteOpThread",
+ [op, retvals, num_retvals, n]() {  // Raw pointers captured by value: they must stay valid until the thread has notified.
+ TFE_Execute(op, retvals, num_retvals, n->status.get());
+ n->n.Notify();  // Signal completion only after the status has been filled in.
+ }));
+
+ return n;
+}
+
+void TFE_ExecuteOpNotificationWaitAndDelete(  // Blocks until the op's thread has notified, copies its status into `status`, then frees the notification.
+ TFE_ExecuteOpNotification* notification, TF_Status* status) {
+ if (notification == nullptr) {  // Nothing to wait on or free: report InvalidArgument and return.
+ status->status = tensorflow::errors::InvalidArgument(
+ "Passed in notification is a nullptr.");
+
+ return;
+ }
+ if (notification->thread == nullptr) {  // No thread was stored, so waiting would never be satisfied; free and report instead.
+ status->status = tensorflow::errors::InvalidArgument(
+ "Passed in notification didn't start a thread correctly. Cleaning up "
+ "this notification. Please re-execute the operation to get a new "
+ "notification.");
+
+ delete notification;
+ return;
+ }
+
+ notification->n.WaitForNotification();  // Returns once the worker thread has called Notify().
+
+ status->status = notification->status->status;  // Copy the result out before the notification (and its status) is destroyed.
+
+ delete notification;  // The pointer must not be used by the caller after this call.
+}
+
+void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}
// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
using tensorflow::AttrBuilder::AttrBuilder;
+ // The string buffers to make sure that any `attr_name` we pass into
+ // `builder->Set()` will outlive the subsequent
+ // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
+ std::set attr_names;
};
TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
@@ -8757,13 +8813,15 @@ void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }
void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
TF_DataType value) {
- builder->Set(attr_name, static_cast(value));
+ auto iter = builder->attr_names.insert(attr_name).first;
+ builder->Set((*iter).c_str(), static_cast(value));
}
void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
const TF_DataType* values, int num_values) {
+ auto iter = builder->attr_names.insert(attr_name).first;
builder->Set(
- attr_name,
+ (*iter).c_str(),
tensorflow::gtl::ArraySlice(
reinterpret_cast(values), num_values));
}
@@ -8800,3 +8858,31 @@ const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
// The returned string is owned by OpRegistry, so liveness is not a concern.
return input_arg.number_attr().c_str();
}
+
+int TF_OpIsStateful(const char* op_type, TF_Status* status) {  // Looks up `op_type` in the global op registry and returns its OpDef's is_stateful() value.
+ const tensorflow::OpRegistrationData* op_reg_data;
+ status->status =
+ tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
+ if (!status->status.ok()) {
+ return 0;  // Lookup failed: `status` carries the error and this 0 is only a placeholder.
+ }
+ return op_reg_data->op_def.is_stateful();  // Implicitly converted to the int return type.
+}
+
+void TF_InitMain(const char* usage, int* argc, char*** argv) {  // C wrapper: forwards all three arguments unchanged to tensorflow::port::InitMain.
+ tensorflow::port::InitMain(usage, argc, argv);
+}
+
+int TF_PickUnusedPortOrDie() {  // C wrapper: returns whatever tensorflow::internal::PickUnusedPortOrDie() returns.
+ return tensorflow::internal::PickUnusedPortOrDie();
+}
+
+TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType dtype_arg,  // Builds a rank-0 tensor of `dtype_arg`, copies `len` bytes from `data` into it and wraps it in a new handle.
+ void* data, size_t len) {
+ auto dtype = static_cast(dtype_arg);
+ DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));  // The dtype is only checked through DCHECK.
+
+ tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));  // Empty shape, i.e. a scalar.
+ std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);  // NOTE(review): `len` is not validated against the tensor buffer size; presumably callers must pass exactly the scalar's byte size - confirm.
+ return new TFE_TensorHandle(tensor, nullptr, nullptr);  // Allocated with new and returned as a raw pointer.
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 6639b0be72bdf81d0e3c806770364d7bc5082ad2..3e3a485eb763b871b0551414c4ef04746b2ed9a3 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -180,6 +180,25 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
TFE_TensorHandle* handle);
+typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification;
+
+// Allows invoking a kernel asynchronously, and explicitly returns a
+// notification that can be waited upon. This always executes the kernel in a
+// new thread.
+// 1. `retvals` and `num_retvals` can only be consumed after
+// `TFE_ExecuteOpNotificationWaitAndDelete` returns successfully. They
+// shouldn't be used if the return is unsuccessful.
+// 2. These new APIs cannot be used together with the TFE context level async
+// support.
+TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(
+ TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
+ TF_Status* status);
+
+// Waits to complete the op execution, and cleans up the notification.
+// Errors reported by op execution are set in `status`.
+TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
+ TFE_ExecuteOpNotification* notification, TF_Status* status);
+
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);
@@ -209,6 +228,24 @@ TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice(
TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput(
const char* op_name, int input_index, TF_Status* status);
+// Returns 1 if the op is stateful, 0 otherwise. The return value is undefined
+// if the status is not ok.
+TF_CAPI_EXPORT extern int TF_OpIsStateful(const char* op_type,
+ TF_Status* status);
+
+// Platform specific initialization routine. Very few platforms actually require
+// this to be called.
+TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv);
+
+// Platform-specific implementation to return an unused port. (This should be
+// used in tests only.)
+TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void);
+
+// Fast path method that makes constructing a single scalar tensor require less
+// overhead and copies.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar(
+ TF_DataType dtype, void* scalar, size_t len);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc
index c6effd39697e0397278770b53e98508074f99862..daa7701b7fe7e8ce757b6504329cf6434ad39778 100644
--- a/tensorflow/c/c_api_experimental_test.cc
+++ b/tensorflow/c/c_api_experimental_test.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/c_test_util.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -162,5 +164,137 @@ protocol: "grpc"
TF_DeleteStatus(status);
}
+TEST(CAPI_EXPERIMENTAL, IsStateful) {  // Checks TF_OpIsStateful on two registered ops, one expected to be stateful and one not.
+ std::unique_ptr status(
+ TF_NewStatus(), TF_DeleteStatus);
+ int assign = TF_OpIsStateful("AssignAddVariableOp", status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ EXPECT_EQ(assign, 1);  // Expected to be reported as stateful.
+ int id = TF_OpIsStateful("Identity", status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ EXPECT_EQ(id, 0);  // Expected to be reported as not stateful.
+}
+
+TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) {  // Runs one MatMul through TFE_ExecuteOpInNewThread and checks the product after waiting.
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+
+ TFE_Op* matmul_op = MatMulOp(ctx, m, m);  // m multiplied by itself.
+
+ TFE_TensorHandle* retvals[1] = {nullptr};
+ int num_retvals = 1;
+
+ auto* r =
+ TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status);
+
+ TFE_ExecuteOpNotificationWaitAndDelete(r, status);  // Blocks until the op has run; `r` is freed by this call.
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);  // retvals is only read after the wait above.
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);  // These four values are [[1,2],[3,4]] multiplied by itself.
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+
+ TFE_DeleteOp(matmul_op);
+ TFE_DeleteTensorHandle(m);
+
+ TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+}
+
+// Perform a send/recv test. Recv blocks, so the two ops are each launched on
+// their own thread via TFE_ExecuteOpInNewThread.
+TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);  // NOTE(review): `status` has not been passed to any call yet, so this check looks redundant.
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ // Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4.
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+
+ // Build a send op.
+ TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(send_op, m, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ string tensor_name = "Tensor";
+ TFE_OpSetAttrType(send_op, "T", TF_FLOAT);
+ TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(),
+ tensor_name.size());
+ string send_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+ TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(),
+ send_device.size());
+ TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234);
+ string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+ TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(),
+ recv_device.size());
+ TFE_OpSetAttrBool(send_op, "client_terminated", true);
+
+ // Build a recv op.
+ TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT);  // The recv op is given the same tensor name, devices and incarnation as the send op.
+ TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(),
+ tensor_name.size());
+ TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(),
+ send_device.size());
+ TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234);
+ TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(),
+ recv_device.size());
+ TFE_OpSetAttrBool(recv_op, "client_terminated", true);
+
+ TFE_TensorHandle* send_retvals;  // Left uninitialised; the send is launched with send_num_retvals == 0.
+ int send_num_retvals = 0;
+ auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals,
+ &send_num_retvals, status);
+
+ TFE_TensorHandle* recv_retvals[1] = {nullptr};
+ int recv_num_retvals = 1;
+ auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0],
+ &recv_num_retvals, status);
+
+ TFE_ExecuteOpNotificationWaitAndDelete(send_result, status);  // Both ops were launched above before either wait.
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ float product[4] = {0};  // Holds the received tensor; despite the name no product is computed in this test.
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(1, product[0]);  // The received data must equal the matrix that was sent.
+ EXPECT_EQ(2, product[1]);
+ EXPECT_EQ(3, product[2]);
+ EXPECT_EQ(4, product[3]);
+
+ TFE_DeleteOp(send_op);
+ TFE_DeleteOp(recv_op);
+ TFE_DeleteTensorHandle(m);
+
+ TFE_DeleteTensorHandle(recv_retvals[0]);
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index f68f8a3e90a971b5e4a024feaf26ba498afc48da..28b9f8df9c873ee394eb6a241dd9ac06ba6c8796 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -392,26 +392,26 @@ Status ProcessInputs(
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
- const Node& node = inputs[i].oper->node;
+ Node* node = &inputs[i].oper->node;
int idx = inputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
- fn_body->graph.IsValidOutputTensor(&node, idx),
+ fn_body->graph.IsValidOutputTensor(node, idx),
"Encountered while processing input ", i, " into function '", fn_name,
"'");
- TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
"Encountered while processing input ", i,
" into function '", fn_name, "'");
- input_tensors->emplace_back(&node, idx);
+ input_tensors->emplace_back(node, idx);
- const auto& iter = input_nodes->find(&node);
+ const auto& iter = input_nodes->find(node);
if (iter == input_nodes->end()) {
- input_nodes->insert({&node, {idx}});
+ input_nodes->insert({node, {idx}});
} else {
auto& indices = iter->second;
if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
- return InvalidArgument("TF_Output ", node.name(), ":", idx,
+ return InvalidArgument("TF_Output ", node->name(), ":", idx,
" appears more than once in the input list");
}
indices.push_back(idx);
@@ -428,16 +428,16 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
- const Node& node = outputs[i].oper->node;
+ Node* node = &outputs[i].oper->node;
int idx = outputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
- fn_body->graph.IsValidOutputTensor(&node, idx),
+ fn_body->graph.IsValidOutputTensor(node, idx),
"Encountered while processing output ", i, " from function '", fn_name,
"'");
- TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
"Encountered while creating function '",
fn_name, "'");
- output_tensors->emplace_back(&node, idx);
+ output_tensors->emplace_back(node, idx);
}
return Status::OK();
}
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index ba3d8533db7623b8fa7fdf35093abcd1450776b1..c34a84fcfee9b6ba9a7be86ae16e2856a2d343c7 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -50,6 +50,7 @@ tf_cuda_library(
],
"//conditions:default": [],
}) + [
+ "@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
@@ -143,6 +144,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 408277468d7beb23d1b2ab7f9bbccac16332e55a..027d752f420238da867cb9d8c116640e1730caaa 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -21,9 +21,11 @@ limitations under the License.
#include
#include
+#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/core/platform/host_info.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif // TENSORFLOW_EAGER_USE_XLA
@@ -79,7 +81,7 @@ tensorflow::Status GetAllRemoteDevices(
const std::vector& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr* device_mgr) {
- std::vector remote_devices;
+ std::vector> remote_devices;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
for (const string& remote_worker : remote_workers) {
@@ -92,7 +94,7 @@ tensorflow::Status GetAllRemoteDevices(
status = s;
if (s.ok()) {
for (tensorflow::Device* d : *devices) {
- remote_devices.push_back(d);
+ remote_devices.emplace_back(d);
}
}
n.Notify();
@@ -100,7 +102,7 @@ tensorflow::Status GetAllRemoteDevices(
n.WaitForNotification();
}
std::unique_ptr remote_device_mgr(
- new tensorflow::DeviceMgr(remote_devices));
+ new tensorflow::DeviceMgr(std::move(remote_devices)));
TF_RETURN_IF_ERROR(status);
@@ -261,13 +263,13 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
- std::vector devices;
+ std::vector> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
&devices);
if (!status->status.ok()) return nullptr;
std::unique_ptr device_mgr(
- new tensorflow::DeviceMgr(devices));
+ new tensorflow::DeviceMgr(std::move(devices)));
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
@@ -409,6 +411,18 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
: d->name().c_str();
}
+const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,  // Returns the name of the device reported by h->handle->device(), or nullptr with InvalidArgument for a null handle.
+ TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return nullptr;
+ }
+ tensorflow::Device* d = h->handle->device();
+ return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"  // A null device is reported under the local CPU:0 name.
+ : d->name().c_str();  // Points into the device's own name string.
+}
+
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
@@ -458,13 +472,20 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
- status->status = tensorflow::AttrTypeMapForOp(name, &types);
- if (status->status.ok()) return new TFE_Op(ctx, name, types);
- if (TF_GetCode(status) == TF_NOT_FOUND) {
- if (ctx->context.FindFunctionByName(name)) {
- status->status = tensorflow::Status::OK();
- return new TFE_Op(ctx, name, nullptr);
+ bool is_function = false;
+ status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
+ if (status->status.ok()) {
+ if (is_function && !ctx->context.FindFunctionByName(name)) {
+ status->status = tensorflow::errors::NotFound(
+ "'", name,
+ "' is neither a type of a primitive operation nor a name "
+ "of a function registered in binary running on ",
+ tensorflow::port::Hostname(),
+ ". Make sure the operation or function is "
+ "registered in the binary running in this process.");
+ return nullptr;
}
+ return new TFE_Op(ctx, name, is_function, types);
}
return nullptr;
}
@@ -497,12 +518,6 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret;
- if (op->operation.is_function()) {
- status->status = tensorflow::errors::Unimplemented(
- "TODO(apassos): Support for attributes for TensorFlow functions is not "
- "ready yet.");
- return TF_ATTR_INT; // The compiler requires that we return something.
- }
status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
attr_name, &ret, is_list);
return ret;
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index b2454d872207e26feb3764671474a5d87c01f84d..f80ae5a6d02d4d613c95cf8486e0fc0aeed3affc 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -48,7 +48,7 @@ extern "C" {
typedef struct TFE_ContextOptions TFE_ContextOptions;
// Return a new options object.
-TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions();
+TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(void);
// Set the config in TF_ContextOptions.options.
// config should be a serialized tensorflow.ConfigProto proto.
@@ -169,10 +169,33 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h,
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
int dim_index,
TF_Status* status);
+
+// Returns the device of the operation that produced `h`.
+// If `h` was produced by a copy, returns the destination device of
+// the copy. Note that returned device name is not always the device
+// holding the tensor handle's memory. If you want the latter, use
+// TFE_TensorHandleBackingDeviceName.
+// This function will block till the operation that produces `h` has completed.
+//
+// Device on which the kernel of the operation that produced `h` ran.
+//
+// If `h` was produced by a copy, returns the destination device of
+// the copy.
+//
+// Note that returned device name is not always the device that owns the memory
+// that backs the tensor handle. For the latter see
+// TFE_TensorHandleBackingDeviceName.
+//
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
TFE_TensorHandle* h, TF_Status* status);
+// Returns the name of the device in whose memory `h` resides.
+//
+// This function will block till the operation that produces `h` has completed.
+TF_CAPI_EXPORT extern const char* TFE_TensorHandleBackingDeviceName(
+ TFE_TensorHandle* h, TF_Status* status);
+
// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor
// with `h`. On success, `status` is set to OK. On failure, `status` reflects
// the error and a nullptr is returned.
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index fa1b22e3af487b19b8b7885b7c3740b6249c73eb..67bc1bcd24605f8363d6a7c8d5d6a0836a42fc82 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -93,10 +93,9 @@ struct TFE_TensorDebugInfo {
};
struct TFE_Op {
- // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
- // primitive operation.
- TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
- : operation(&ctx->context, op, t) {}
+ TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
+ const tensorflow::AttrTypeMap* t)
+ : operation(&ctx->context, op, is_function, t) {}
tensorflow::EagerOperation operation;
};
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 55331022b9dbd0696928fa44430f340f371432ac..6b39b79ee82f9c7baaf856e573a42b7da65691e5 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include
+#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -589,9 +590,22 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
+ bool has_gpu0 = false;
+ bool has_gpu1 = false;
+ for (int i = 0; i < num_devices; ++i) {
+ const char* dev = TF_DeviceListName(devices, i, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ string device_name(dev);
+ if (device_name.find("GPU:0") != string::npos) {
+ has_gpu0 = true;
+ }
+ if (device_name.find("GPU:1") != string::npos) {
+ has_gpu1 = true;
+ }
+ }
const char* kCPUDevice = "CPU:0";
- if (num_devices < 3) {
+ if (!has_gpu0 || !has_gpu1) {
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
@@ -781,6 +795,14 @@ TEST(CAPI, TensorHandleNullptr) {
TF_SetStatus(status.get(), TF_OK, "");
+ device_name = TFE_TensorHandleBackingDeviceName(h, status.get());
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+ ASSERT_EQ(device_name, nullptr);
+ ASSERT_EQ("The passed in handle is a nullptr",
+ string(TF_Message(status.get())));
+
+ TF_SetStatus(status.get(), TF_OK, "");
+
int num_dims = TFE_TensorHandleNumDims(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(num_dims, -1);
@@ -796,6 +818,62 @@ TEST(CAPI, TensorHandleNullptr) {
string(TF_Message(status.get())));
}
+TEST(CAPI, TensorHandleDevices) {  // Compares TFE_TensorHandleDeviceName with TFE_TensorHandleBackingDeviceName on CPU and, if present, GPU.
+ std::unique_ptr status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status.get());
+ TFE_DeleteContextOptions(opts);
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+ const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name;  // For the CPU test handle both names must contain CPU:0.
+ const char* backing_device_name =
+ TFE_TensorHandleBackingDeviceName(hcpu, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
+ << backing_device_name;
+
+ // Disable the test if no GPU is present.
+ string gpu_device_name;
+ if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
+ TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
+ hcpu, ctx, gpu_device_name.c_str(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TFE_Op* shape_op = ShapeOp(ctx, hgpu);
+ TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());  // Request that the Shape op itself runs on the GPU.
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // .device of shape is GPU since the op is executed on GPU
+ device_name = TFE_TensorHandleDeviceName(retvals[0], status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(device_name, "GPU:0")) << device_name;
+
+ // .backing_device of shape is CPU since the tensor is backed by CPU
+ backing_device_name =
+ TFE_TensorHandleBackingDeviceName(retvals[0], status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
+ << backing_device_name;
+
+ TFE_DeleteOp(shape_op);
+ TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteTensorHandle(hgpu);
+ }
+
+ TFE_DeleteTensorHandle(hcpu);
+ TFE_ContextAsyncWait(ctx, status.get());  // Wait on the context before it is deleted.
+ EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_DeleteContext(ctx);
+}
+
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc
index 008f088c2dcdd7d9114103516a4702e47a55c6de..bd38127d50c171af801dd1b937acefdba491b4a6 100644
--- a/tensorflow/c/eager/c_api_test_util.cc
+++ b/tensorflow/c/eager/c_api_test_util.cc
@@ -104,6 +104,19 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
return op;
}
+TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {  // Test helper: builds a "Shape" op with `a` as its input; CHECK-fails on any error.
+ TF_Status* status = TF_NewStatus();
+
+ TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, a, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);  // The temporary status is not needed by the remaining calls.
+ TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));  // The "T" attr is set to the data type of the input handle.
+
+ return op;  // The tests in this change release it with TFE_DeleteOp.
+}
+
TFE_TensorHandle* TestAxisTensorHandle() {
int64_t dims[] = {1};
int data[] = {1};
diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h
index 474cae67c89249af3a62707f0db00ba458ca8f31..75ef9459e93b4f2ed471c423a34565594efc1714 100644
--- a/tensorflow/c/eager/c_api_test_util.h
+++ b/tensorflow/c/eager/c_api_test_util.h
@@ -37,6 +37,9 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2();
// Return a matmul op multiplying `a` by `b`.
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
+// Return a shape op fetching the shape of `a`.
+TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);
+
// Return an 1-D INT32 tensor containing a single value 1.
TFE_TensorHandle* TestAxisTensorHandle();
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 5ba55a203ff70cc64c07e96b5a869a1f11c9334e..5c11f51e8749de84547ae873f5f55ebd42bc4b3d 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -141,8 +141,9 @@ class GradientTape {
// null. The result is populated with one tensor per target element.
Status ComputeGradient(
const VSpace& vspace,
- gtl::ArraySlice target_tensor_ids,
- gtl::ArraySlice source_tensor_id,
+ const gtl::ArraySlice target_tensor_ids,
+ const gtl::ArraySlice source_tensor_ids,
+ const gtl::FlatMap sources_that_are_targets,
gtl::ArraySlice output_gradients,
std::vector* result);
@@ -396,6 +397,7 @@ template
Status InitialGradients(
const VSpace& vspace,
gtl::ArraySlice target_tensor_ids,
+ gtl::FlatMap sources_that_are_targets,
gtl::ArraySlice output_gradients, const TensorTape& tensor_tape,
const OpTape& op_tape,
gtl::FlatMap>* result) {
@@ -425,8 +427,13 @@ Status InitialGradients(
"none of operations outputs match expected tensor");
}
} else {
- // No record of the target tensor found on the tape, so no gradient
- // needs to be computed from it. Do nothing.
+ // This target tensor was not generated by any operation recorded on
+ // the tape, so no gradient needs to be computed from it unless this
+ // target is also a source.
+ auto source_tensor = sources_that_are_targets.find(id);
+ if (source_tensor != sources_that_are_targets.end()) {
+ (*result)[id].push_back(vspace.Ones(source_tensor->second));
+ }
}
} else {
(*result)[id].push_back(output_gradients[i]);
@@ -467,8 +474,9 @@ constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
template
Status GradientTape::ComputeGradient(
const VSpace& vspace,
- gtl::ArraySlice target_tensor_ids,
- gtl::ArraySlice source_tensor_ids,
+ const gtl::ArraySlice target_tensor_ids,
+ const gtl::ArraySlice source_tensor_ids,
+ const gtl::FlatMap sources_that_are_targets,
gtl::ArraySlice output_gradients,
std::vector* result) {
gtl::FlatSet sources_set(source_tensor_ids.begin(),
@@ -478,7 +486,8 @@ Status GradientTape::ComputeGradient(
std::vector op_stack =
InitialStack(state.op_tape, state.op_missing_tensor);
gtl::FlatMap> gradients;
- Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
+ Status s = InitialGradients(vspace, target_tensor_ids,
+ sources_that_are_targets, output_gradients,
tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() {
if (!persistent_) {
diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc
new file mode 100644
index 0000000000000000000000000000000000000000..07b9e8b940c55caf62ae0b81b884bf313d335459
--- /dev/null
+++ b/tensorflow/c/env.cc
@@ -0,0 +1,161 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/env.h"
+
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
+
+struct TF_StringStream {
+ std::vector<::tensorflow::string>* list;
+ size_t position;
+};
+
+void TF_CreateDir(const char* dirname, TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+ ::tensorflow::Set_TF_Status_from_Status(
+ status, ::tensorflow::Env::Default()->CreateDir(dirname));
+}
+
+void TF_DeleteDir(const char* dirname, TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+ ::tensorflow::Set_TF_Status_from_Status(
+ status, ::tensorflow::Env::Default()->DeleteDir(dirname));
+}
+
+void TF_DeleteRecursively(const char* dirname, int64_t* undeleted_file_count,
+ int64_t* undeleted_dir_count, TF_Status* status) {
+ ::tensorflow::int64 f, d;
+
+ TF_SetStatus(status, TF_OK, "");
+ ::tensorflow::Set_TF_Status_from_Status(
+ status, ::tensorflow::Env::Default()->DeleteRecursively(dirname, &f, &d));
+ *undeleted_file_count = f;
+ *undeleted_dir_count = d;
+}
+
+void TF_FileStat(const char* filename, TF_FileStatistics* stats,
+ TF_Status* status) {
+ ::tensorflow::FileStatistics cc_stats;
+ TF_SetStatus(status, TF_OK, "");
+ ::tensorflow::Status s =
+ ::tensorflow::Env::Default()->Stat(filename, &cc_stats);
+ ::tensorflow::Set_TF_Status_from_Status(status, s);
+ if (s.ok()) {
+ stats->length = cc_stats.length;
+ stats->mtime_nsec = cc_stats.mtime_nsec;
+ stats->is_directory = cc_stats.is_directory;
+ }
+}
+
+void TF_NewWritableFile(const char* filename, TF_WritableFileHandle** handle,
+                        TF_Status* status) {
+  std::unique_ptr<::tensorflow::WritableFile> f;
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::Status s =
+      ::tensorflow::Env::Default()->NewWritableFile(filename, &f);
+  ::tensorflow::Set_TF_Status_from_Status(status, s);
+  // Release ownership to the caller only on success.
+  if (s.ok()) {
+    *handle = reinterpret_cast<TF_WritableFileHandle*>(f.release());
+  }
+}
+
+void TF_CloseWritableFile(TF_WritableFileHandle* handle, TF_Status* status) {
+ auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle);
+ TF_SetStatus(status, TF_OK, "");
+ ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Close());
+ delete cc_file;
+}
+
+void TF_SyncWritableFile(TF_WritableFileHandle* handle, TF_Status* status) {
+ auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle);
+ TF_SetStatus(status, TF_OK, "");
+ ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Sync());
+}
+
+void TF_FlushWritableFile(TF_WritableFileHandle* handle, TF_Status* status) {
+ auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle);
+ TF_SetStatus(status, TF_OK, "");
+ ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Flush());
+}
+
+void TF_AppendWritableFile(TF_WritableFileHandle* handle, const char* data,
+ size_t length, TF_Status* status) {
+ auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle);
+ TF_SetStatus(status, TF_OK, "");
+ ::tensorflow::Set_TF_Status_from_Status(
+ status, cc_file->Append(::tensorflow::StringPiece{data, length}));
+}
+
+void TF_DeleteFile(const char* filename, TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+ ::tensorflow::Set_TF_Status_from_Status(
+ status, ::tensorflow::Env::Default()->DeleteFile(filename));
+}
+
+bool TF_StringStreamNext(TF_StringStream* list, const char** result) {
+ if (list->position >= list->list->size()) {
+ *result = nullptr;
+ return false;
+ }
+
+ *result = list->list->at(list->position++).c_str();
+ return true;
+}
+
+void TF_StringStreamDone(TF_StringStream* list) {
+ delete list->list;
+ delete list;
+}
+TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) {
+ auto* children = new std::vector<::tensorflow::string>;
+
+ TF_SetStatus(status, TF_OK, "");
+ ::tensorflow::Set_TF_Status_from_Status(
+ status, ::tensorflow::Env::Default()->GetChildren(dirname, children));
+
+ auto* list = new TF_StringStream;
+ list->list = children;
+ list->position = 0;
+ return list;
+}
+
+TF_StringStream* TF_GetLocalTempDirectories() {
+ auto* tmpdirs = new std::vector<::tensorflow::string>;
+
+ ::tensorflow::Env::Default()->GetLocalTempDirectories(tmpdirs);
+
+ auto* list = new TF_StringStream;
+ list->list = tmpdirs;
+ list->position = 0;
+ return list;
+}
+
+TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) {
+ return ::tensorflow::Env::Default()->NowNanos();
+}
+
+// Returns the number of microseconds since the Unix epoch.
+TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) {
+ return ::tensorflow::Env::Default()->NowMicros();
+}
+
+// Returns the number of seconds since the Unix epoch.
+TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) {
+ return ::tensorflow::Env::Default()->NowSeconds();
+}
diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h
new file mode 100644
index 0000000000000000000000000000000000000000..9d27c5da37735042c7476b591e57486dbde33152
--- /dev/null
+++ b/tensorflow/c/env.h
@@ -0,0 +1,157 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_ENV_H_
+#define TENSORFLOW_C_ENV_H_
+
+#include "tensorflow/c/c_api.h"
+
+// --------------------------------------------------------------------------
+// C API for tensorflow::Env.
+
+struct TF_WritableFileHandle;
+struct TF_StringStream;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct TF_FileStatistics {
+ // The length of the file in bytes.
+ int64_t length;
+ // The last modified time in nanoseconds.
+ int64_t mtime_nsec;
+ // Whether the name refers to a directory.
+ bool is_directory;
+} TF_FileStatistics;
+
+// Creates the specified directory. Typical status codes are:
+// * TF_OK - successfully created the directory
+// * TF_ALREADY_EXISTS - directory already exists
+// * TF_PERMISSION_DENIED - dirname is not writable
+TF_CAPI_EXPORT extern void TF_CreateDir(const char* dirname, TF_Status* status);
+
+// Deletes the specified directory. Typical status codes are:
+// * TF_OK - successfully deleted the directory
+// * TF_FAILED_PRECONDITION - the directory is not empty
+TF_CAPI_EXPORT extern void TF_DeleteDir(const char* dirname, TF_Status* status);
+
+// Deletes the specified directory and all subdirectories and files underneath
+// it. This is accomplished by traversing the directory tree rooted at dirname
+// and deleting entries as they are encountered.
+//
+// If dirname itself is not readable or does not exist, *undeleted_dir_count is
+// set to 1, *undeleted_file_count is set to 0 and an appropriate status (e.g.
+// TF_NOT_FOUND) is returned.
+//
+// If dirname and all its descendants were successfully deleted, TF_OK is
+// returned and both error counters are set to zero.
+//
+// Otherwise, while traversing the tree, undeleted_file_count and
+// undeleted_dir_count are updated if an entry of the corresponding type could
+// not be deleted. The returned error status represents the reason that any one
+// of these entries could not be deleted.
+//
+// Typical status codes:
+// * TF_OK - dirname exists and we were able to delete everything underneath
+// * TF_NOT_FOUND - dirname doesn't exist
+// * TF_PERMISSION_DENIED - dirname or some descendant is not writable
+// * TF_UNIMPLEMENTED - some underlying functions (like Delete) are not
+// implemented
+TF_CAPI_EXPORT extern void TF_DeleteRecursively(const char* dirname,
+ int64_t* undeleted_file_count,
+ int64_t* undeleted_dir_count,
+ TF_Status* status);
+
+// Obtains statistics for the given path. If status is TF_OK, *stats is
+// updated, otherwise it is not touched.
+TF_CAPI_EXPORT extern void TF_FileStat(const char* filename,
+ TF_FileStatistics* stats,
+ TF_Status* status);
+
+// Creates or truncates the given filename and returns a handle to be used for
+// appending data to the file. If status is TF_OK, *handle is updated and the
+// caller is responsible for freeing it (see TF_CloseWritableFile).
+TF_CAPI_EXPORT extern void TF_NewWritableFile(const char* filename,
+ TF_WritableFileHandle** handle,
+ TF_Status* status);
+
+// Closes the given handle and frees its memory. If there was a problem closing
+// the file, it is indicated by status. Memory is freed in any case.
+TF_CAPI_EXPORT extern void TF_CloseWritableFile(TF_WritableFileHandle* handle,
+ TF_Status* status);
+
+// Syncs content of the handle to the filesystem. Blocks waiting for the
+// filesystem to indicate that the content has been persisted.
+TF_CAPI_EXPORT extern void TF_SyncWritableFile(TF_WritableFileHandle* handle,
+ TF_Status* status);
+
+// Flush local buffers to the filesystem. If the process terminates after a
+// successful flush, the contents may still be persisted, since the underlying
+// filesystem may eventually flush the contents. If the OS or machine crashes
+// after a successful flush, the contents may or may not be persisted, depending
+// on the implementation.
+TF_CAPI_EXPORT extern void TF_FlushWritableFile(TF_WritableFileHandle* handle,
+ TF_Status* status);
+
+// Appends the given bytes to the file. Any failure to do so is indicated in
+// status.
+TF_CAPI_EXPORT extern void TF_AppendWritableFile(TF_WritableFileHandle* handle,
+ const char* data,
+ size_t length,
+ TF_Status* status);
+
+// Deletes the named file and indicates whether successful in *status.
+TF_CAPI_EXPORT extern void TF_DeleteFile(const char* filename,
+ TF_Status* status);
+
+// Retrieves the next item from the given TF_StringStream and places a pointer
+// to it in *result. If no more items are in the list, *result is set to NULL
+// and false is returned.
+//
+// Ownership of the items retrieved with this function remains with the library.
+// Item pointers are invalidated after a call to TF_StringStreamDone.
+TF_CAPI_EXPORT extern bool TF_StringStreamNext(TF_StringStream* list,
+ const char** result);
+
+// Frees the resources associated with given string list. All pointers returned
+// by TF_StringStreamNext are invalid after this call.
+TF_CAPI_EXPORT extern void TF_StringStreamDone(TF_StringStream* list);
+
+// Retrieves the list of children of the given directory. You can iterate
+// through the list with TF_StringStreamNext. The caller is responsible for
+// freeing the list (see TF_StringStreamDone).
+TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename,
+ TF_Status* status);
+
+// Retrieves a list of directory names on the local machine that may be used for
+// temporary storage. You can iterate through the list with TF_StringStreamNext.
+// The caller is responsible for freeing the list (see TF_StringStreamDone).
+TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void);
+
+// Returns the number of nanoseconds since the Unix epoch.
+TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void);
+
+// Returns the number of microseconds since the Unix epoch.
+TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void);
+
+// Returns the number of seconds since the Unix epoch.
+TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // TENSORFLOW_C_ENV_H_
diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e2206c6befd2167346c64032940d6e8c631e4a3e
--- /dev/null
+++ b/tensorflow/c/env_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/env.h"
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x))
+
+TEST(TestEnv, TestDirHandling) {
+ TF_StringStream* tempdirs = TF_GetLocalTempDirectories();
+ const char* tempdir;
+ bool found = false;
+ while (TF_StringStreamNext(tempdirs, &tempdir)) {
+ found = true;
+
+ TF_Status* s = TF_NewStatus();
+
+ ::tensorflow::string dirpath =
+ ::tensorflow::io::JoinPath(tempdir, "somedir");
+ TF_CreateDir(dirpath.c_str(), s);
+ ASSERT_TF_OK(s) << "TF_CreateDir failed for " << dirpath << ": "
+ << TF_Message(s);
+
+ ::tensorflow::string filepath =
+ ::tensorflow::io::JoinPath(dirpath, "somefile.txt");
+ TF_WritableFileHandle* handle;
+ TF_NewWritableFile(filepath.c_str(), &handle, s);
+ ASSERT_TF_OK(s) << "NewWritableFile failed for " << filepath << ": "
+ << TF_Message(s);
+
+ const char* data = "Hello, world!\n";
+ TF_AppendWritableFile(handle, data, strlen(data), s);
+ ASSERT_TF_OK(s) << "TF_AppendWritableFile failed to append data to file at "
+ << filepath << ": " << TF_Message(s);
+
+ TF_CloseWritableFile(handle, s);
+ ASSERT_TF_OK(s) << "TF_CloseWritableFile failed to close handle to "
+ << filepath << ": " << TF_Message(s);
+
+ TF_StringStream* children = TF_GetChildren(dirpath.c_str(), s);
+ ASSERT_TF_OK(s) << "TF_GetChildren failed for " << dirpath;
+ const char* childpath;
+ ASSERT_TRUE(TF_StringStreamNext(children, &childpath));
+ ASSERT_EQ(::tensorflow::string(childpath), "somefile.txt");
+ // There should only be one file in this directory.
+ ASSERT_FALSE(TF_StringStreamNext(children, &childpath));
+ ASSERT_EQ(childpath, nullptr);
+ TF_StringStreamDone(children);
+
+ TF_FileStatistics stats;
+ TF_FileStat(filepath.c_str(), &stats, s);
+ ASSERT_EQ(stats.length, strlen(data));
+ ASSERT_FALSE(stats.is_directory);
+ ASSERT_GT(stats.mtime_nsec, 0);
+
+ // Trying to delete a non-empty directory should fail.
+ TF_DeleteDir(dirpath.c_str(), s);
+ ASSERT_NE(TF_OK, TF_GetCode(s))
+ << "TF_DeleteDir unexpectedly succeeded with a non-empty directory "
+ << dirpath;
+
+ TF_DeleteFile(filepath.c_str(), s);
+ ASSERT_TF_OK(s) << "TF_DeleteFile failed for " << filepath << ": "
+ << TF_Message(s);
+
+ // Now deleting the directory should work.
+ TF_DeleteDir(dirpath.c_str(), s);
+ ASSERT_TF_OK(s) << "TF_DeleteDir failed for " << dirpath << ": "
+ << TF_Message(s);
+
+ TF_DeleteStatus(s);
+ break;
+ }
+
+ ASSERT_TRUE(found) << "expected at least one temp dir";
+
+ TF_StringStreamDone(tempdirs);
+}
+
+TEST(TestEnv, TestTimeFunctions) {
+ ASSERT_GE(TF_NowSeconds(), 946684800); // Midnight Jan 1, 2000
+ ASSERT_GE(TF_NowMicros(), 946684800 * 1e6);
+ ASSERT_GE(TF_NowNanos(), 946684800 * 1e9);
+}
diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2a4eaecb6cf2740a522b1e849d1306ebde6c4577
--- /dev/null
+++ b/tensorflow/c/kernels.cc
@@ -0,0 +1,160 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/kernels.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+// This file forms the basis of a stable ABI for third-party kernel
+// implementations. It is crucial that changes to this file are made cautiously
+// and with a focus on maintaining both source and binary compatibility.
+
+struct TF_KernelBuilder {
+ ::tensorflow::KernelDefBuilder* cc_builder;
+
+ void* (*create_function)(TF_OpKernelConstruction*);
+ void (*compute_function)(void*, TF_OpKernelContext*);
+ void (*delete_function)(void*);
+};
+
+TF_KernelBuilder* TF_NewKernelBuilder(
+ const char* op_name, const char* device_name,
+ void* (*create_func)(TF_OpKernelConstruction*),
+ void (*compute_func)(void*, TF_OpKernelContext*),
+ void (*delete_func)(void*)) {
+ TF_KernelBuilder* result = new TF_KernelBuilder;
+ result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
+ result->cc_builder->Device(device_name);
+ result->create_function = create_func;
+ result->compute_function = compute_func;
+ result->delete_function = delete_func;
+ return result;
+}
+
+void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
+ DCHECK_NE(builder, nullptr);
+ delete builder->cc_builder;
+ delete builder;
+}
+
+namespace tensorflow {
+namespace {
+
+// An OpKernel whose methods delegate to C function pointers.
+class COpKernel : public OpKernel {
+ public:
+  explicit COpKernel(OpKernelConstruction* ctx,
+                     void* (*create_func)(TF_OpKernelConstruction*),
+                     void (*compute_func)(void*, TF_OpKernelContext*),
+                     void (*delete_func)(void*))
+      : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
+    if (create_func != nullptr) {
+      c_kernel_ =
+          (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
+    } else {
+      c_kernel_ = nullptr;
+    }
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
+  }
+
+  ~COpKernel() override {
+    if (delete_func_ != nullptr) {
+      (*delete_func_)(c_kernel_);
+    }
+  }
+
+ private:
+  void (*compute_func_)(void*, TF_OpKernelContext* context);
+  void (*delete_func_)(void*);  // May be nullptr; then nothing is called.
+  void* c_kernel_;  // Value returned by create_func, or nullptr if none.
+};
+
+// A KernelFactory that returns COpKernel instances.
+class KernelBuilderFactory
+ : public ::tensorflow::kernel_factory::OpKernelFactory {
+ public:
+ explicit KernelBuilderFactory(TF_KernelBuilder* builder)
+ : builder_(builder) {}
+ ::tensorflow::OpKernel* Create(
+ ::tensorflow::OpKernelConstruction* context) override {
+ return new ::tensorflow::COpKernel(context, builder_->create_function,
+ builder_->compute_function,
+ builder_->delete_function);
+ }
+ ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }
+
+ private:
+ TF_KernelBuilder* builder_;
+};
+} // namespace
+} // namespace tensorflow
+
+void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
+                              TF_Status* status) {
+  using tensorflow::register_kernel::Name;
+  // The factory deletes `builder` when the factory itself is destroyed.
+  tensorflow::kernel_factory::OpKernelRegistrar(
+      builder->cc_builder->Build(), name,
+      absl::make_unique<tensorflow::KernelBuilderFactory>(builder));
+
+  TF_SetStatus(status, TF_OK, "");
+}
+
+int TF_NumInputs(TF_OpKernelContext* ctx) {
+ auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
+ return cc_ctx->num_inputs();
+}
+
+int TF_NumOutputs(TF_OpKernelContext* ctx) {
+ auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
+ return cc_ctx->num_outputs();
+}
+
+void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
+ TF_Status* status) {
+ auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
+ if (i < 0 || i >= cc_ctx->num_inputs()) {
+ TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
+ return;
+ }
+ const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
+ TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
+ if (TF_GetCode(status) == TF_OK) {
+ *tensor = result;
+ }
+}
+
+void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
+                  TF_Status* status) {
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
+  if (i < 0 || i >= cc_ctx->num_outputs()) {  // Bound by outputs, not inputs.
+    TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
+    return;
+  }
+  ::tensorflow::Tensor cc_tensor;
+  ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
+  TF_SetStatus(status, TF_OK, "");
+  ::tensorflow::Set_TF_Status_from_Status(status, s);
+  if (s.ok()) {  // Leave ctx unmodified when the conversion failed.
+    cc_ctx->set_output(i, cc_tensor);
+  }
+}
diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h
new file mode 100644
index 0000000000000000000000000000000000000000..1a91aa184f11ac8e45b38a1d106c7b445747a7c1
--- /dev/null
+++ b/tensorflow/c/kernels.h
@@ -0,0 +1,118 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_KERNELS_H_
+#define TENSORFLOW_C_KERNELS_H_
+
+#include "tensorflow/c/c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// --------------------------------------------------------------------------
+// C API for TensorFlow Kernels.
+//
+// This API allows developers to register custom kernel implementations for
+// TensorFlow.
+//
+// See c_api.h header comments for a discussion about API conventions.
+//
+// Users wishing to extend TensorFlow with new kernels will call
+// `TF_NewKernelBuilder`. The resulting kernel builder can be registered with
+// `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided
+// kernels when necessary.
+
+struct TF_KernelBuilder;
+struct TF_OpKernelConstruction;
+struct TF_OpKernelContext;
+
+// Allocates a new kernel builder and returns a pointer to it.
+//
+// If non-null, TensorFlow will call create_func when it needs to instantiate
+// the kernel. The pointer returned by create_func will be passed to
+// compute_func and delete_func, thereby functioning as a "this" pointer for
+// referring to kernel instances.
+//
+// The TF_OpKernelConstruction pointer passed to create_func is owned by
+// TensorFlow and will be deleted once create_func returns. It must not be used
+// after this.
+//
+// When TensorFlow needs to perform a computation with this kernel, it will
+// call compute_func. This function will receive the pointer returned by
+// create_func (or null if no create_func was provided), along with the inputs
+// to the computation.
+//
+// The TF_OpKernelContext pointer received by compute_func is owned by
+// TensorFlow and will be deleted once compute_func returns. It must not be used
+// after this.
+//
+// Finally, when TensorFlow no longer needs the kernel, it will call
+// delete_func if one is provided. This function will receive the pointer
+// returned in `create_func` or nullptr if no `create_func` was provided.
+//
+// The caller should pass the result of this function to
+// TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for
+// some reason, the kernel builder will not be registered, the caller should
+// delete it with TF_DeleteKernelBuilder.
+TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder(
+ const char* op_name, const char* device_name,
+ void* (*create_func)(TF_OpKernelConstruction*),
+ void (*compute_func)(void*, TF_OpKernelContext*),
+ void (*delete_func)(void*));
+
+// Register the given kernel builder with the TensorFlow runtime. If
+// registration fails, the given status will be populated.
+//
+// This call takes ownership of the `builder` pointer.
+TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name,
+ TF_KernelBuilder* builder,
+ TF_Status* status);
+
+// Deletes the given TF_KernelBuilder. This should be called only if the kernel
+// builder is not registered with TensorFlow via TF_RegisterKernelBuilder.
+TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
+
+// --------------------------------------------------------------------------
+// OpKernelContext routines
+
+// TF_NumInputs returns the number of inputs available in ctx.
+TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
+
+// TF_NumOutputs returns the number of outputs to be placed in *ctx by the
+// kernel.
+TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx);
+
+// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is
+// populated and its ownership is passed to the caller. In any other case,
+// *tensor is not modified.
+//
+// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE.
+TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i,
+ TF_Tensor** tensor, TF_Status* status);
+
+// Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but
+// TF_OK, ctx is left unmodified.
+//
+// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE.
+TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i,
+ const TF_Tensor* tensor,
+ TF_Status* status);
+
+#ifdef __cplusplus
+} /* end extern "C" */
+#endif
+
+#endif // TENSORFLOW_C_KERNELS_H_
diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e659ee3c3d258a626ccf03a782ec031b5a703a48
--- /dev/null
+++ b/tensorflow/c/kernels_test.cc
@@ -0,0 +1,203 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/kernels.h"
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/node_def.pb_text.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+struct MyCustomKernel {
+ bool created;
+ bool compute_called;
+};
+
+static bool delete_called = false;
+
+static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
+ struct MyCustomKernel* s = new struct MyCustomKernel;
+ s->created = true;
+ s->compute_called = false;
+ return s;
+}
+
+static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) {
+  struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
+  s->compute_called = true;  // Checked later by MyDeleteFunc.
+}
+
+static void MyDeleteFunc(void* kernel) {
+  struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
+  EXPECT_TRUE(s->created);
+  EXPECT_TRUE(s->compute_called);
+  delete_called = true;  // Asserted by the test once the kernel is freed.
+  delete s;
+}
+
+namespace tensorflow {
+
+static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
+                                               const char* op_name,
+                                               Status* status) {
+  NodeDef def;
+  def.set_op(op_name);
+  def.set_device(device_name);
+  def.add_input("input1");  // Two inputs, matching the ops the tests register.
+  def.add_input("input2");
+  return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1,
+                        status);
+}
+
+// Tests registration of a single C kernel and checks that calls through the
+// C/C++ boundary are being made.
+TEST(TestKernel, TestRegisterKernelBuilder) {
+ const char* kernel_name = "SomeKernelName";
+ const char* op_name = "FooOp";
+ const char* device_name = "FakeDeviceName1";
+
+ REGISTER_OP(op_name)
+ .Input("input1: double")
+ .Input("input2: uint8")
+ .Output("output1: uint8");
+
+ TF_KernelBuilder* builder = TF_NewKernelBuilder(
+ op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);
+
+ {
+ TF_Status* status = TF_NewStatus();
+ TF_RegisterKernelBuilder(kernel_name, builder, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status));
+ TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status));
+ KernelList list;
+ list.ParseFromArray(buf->data, buf->length);
+ ASSERT_EQ(1, list.kernel_size());
+ ASSERT_EQ(device_name, list.kernel(0).device_type());
+ TF_DeleteBuffer(buf);
+ TF_DeleteStatus(status);
+ }
+
+ {
+ Status status;
+ std::unique_ptr kernel =
+ GetFakeKernel(device_name, op_name, &status);
+ TF_EXPECT_OK(status);
+ ASSERT_NE(nullptr, kernel.get());
+ kernel->Compute(nullptr);
+ }
+
+ ASSERT_TRUE(delete_called);
+}
+
+class DummyDevice : public DeviceBase {
+ public:
+ DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
+ bool RequiresRecordingAccessedTensors() const override { return save_; }
+ Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
+ return cpu_allocator();
+ }
+
+ private:
+ bool save_;
+};
+
+TEST(TestKernel, TestInputAndOutputCount) {
+ const char* kernel_name = "InputOutputCounterKernel";
+ const char* op_name = "BarOp";
+ const char* device_name = "FakeDeviceName2";
+
+ REGISTER_OP(op_name)
+ .Input("input1: double")
+ .Input("input2: uint8")
+ .Output("output1: uint8");
+
+ static int num_inputs = 0;
+ static int num_outputs = 0;
+
+ // A kernel whose Compute function has a side-effect of updating num_inputs
+ // and num_outputs. Various functions on TF_OpKernelContext are also
+ // exercised.
+ auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
+ num_inputs = TF_NumInputs(ctx);
+ num_outputs = TF_NumOutputs(ctx);
+
+ TF_Tensor* input = nullptr;
+ TF_Status* s = TF_NewStatus();
+ TF_GetInput(ctx, 0, &input, s);
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s);
+ EXPECT_EQ(123, *static_cast(TF_TensorData(input)));
+ TF_GetInput(ctx, -1, &input, s);
+ EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
+ TF_GetInput(ctx, 3, &input, s);
+ EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
+
+ // Copy the input tensor to output.
+ TF_SetOutput(ctx, 0, input, s);
+ EXPECT_EQ(TF_OK, TF_GetCode(s));
+
+ TF_SetOutput(ctx, 24, input, s);
+ EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
+
+ TF_DeleteStatus(s);
+ if (input != nullptr) {
+ TF_DeleteTensor(input);
+ }
+ };
+
+ TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
+ my_compute_func, nullptr);
+
+ {
+ TF_Status* status = TF_NewStatus();
+ TF_RegisterKernelBuilder(kernel_name, builder, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status));
+ TF_DeleteStatus(status);
+ }
+
+ {
+ OpKernelContext::Params p;
+ DummyDevice dummy_device(nullptr, false);
+ p.device = &dummy_device;
+
+ Tensor t(tensorflow::uint8(123));
+
+ gtl::InlinedVector inputs;
+ // Simulate 2 inputs
+ inputs.emplace_back(&t);
+ inputs.emplace_back();
+ p.inputs = &inputs;
+
+ Status status;
+ std::unique_ptr kernel =
+ GetFakeKernel(device_name, op_name, &status);
+ TF_EXPECT_OK(status);
+ ASSERT_NE(nullptr, kernel.get());
+
+ p.op_kernel = kernel.get();
+ OpKernelContext ctx(&p);
+ kernel->Compute(&ctx);
+
+ ASSERT_EQ(2, num_inputs);
+ ASSERT_EQ(1, num_outputs);
+ ASSERT_EQ(123, ctx.mutable_output(0)->scalar()());
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 247236b760dd8c07bbb08426100b6a4d34296d2e..98d8393332269ae349cf8aa5c0b612c6f17172e6 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -160,4 +160,17 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}
+void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
+ TF_Status* status) {
+ mutex_lock l(graph->mu);
+ status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
+ new_src.index, &dst->node);
+ if (status->status.ok()) {
+ // This modification only updates the destination node for
+ // the purposes of running this graph in a session. Thus, we don't
+ // record the source node as being modified.
+ RecordMutation(graph, *dst, "adding input tensor");
+ }
+}
+
} // namespace tensorflow
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 5cce84020bc68d912d259f51512341eb5f464a2c..44779ca656165dd65590cb5e9ea3ccf71165ed63 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -34,6 +34,7 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
+// Updates 'dst' to consume 'new_src'.
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status);
@@ -65,6 +66,13 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
// because I couldn't get SWIG to work otherwise.
void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
size_t proto_len, TF_Status* status);
+
+// This method is used to add a new input edge to 'dst', which must be a While
+// op. The While op's "T" attribute must have already been updated to include
+// the new edge. This is used to construct tf.while_loop gradients.
+void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
+ TF_Status* status);
+
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index 83353b79f722f0a95f508b32d4a49b14b35624fb..a09becc49b10d2c58f98fbcc11df5190f794c1d4 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -489,6 +489,7 @@ tf_gen_op_wrappers_cc(
"image_ops",
"io_ops",
"linalg_ops",
+ "list_ops",
"logging_ops",
"lookup_ops",
"manip_ops",
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index 3d3895c8fa82c3c0e2974228e9cad767d0e00df4..52345a376cc29ee47ccb9888c9bb26292468b5a9 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -133,5 +133,6 @@ filegroup(
"testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**",
+ "testdata/half_plus_two_v2/**",
]),
)
diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h
index 645a3f101d1ae7dda88ec4ca622c694dc5a7a919..6f00dc324bd7054b28de2c35023581e1666bfa01 100644
--- a/tensorflow/cc/saved_model/constants.h
+++ b/tensorflow/cc/saved_model/constants.h
@@ -33,10 +33,10 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
/// SavedModel text format proto filename.
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
-/// SavedModel legacy init op key.
+/// SavedModel legacy init op collection key. Used in v1 SavedModels.
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
-/// SavedModel main op key.
+/// SavedModel main op collection key. Used in v1 SavedModels.
constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
/// Directory in which to save the SavedModel variables.
@@ -45,6 +45,11 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
/// SavedModel variables filename.
constexpr char kSavedModelVariablesFilename[] = "variables";
+/// SavedModel SignatureDef keys for the initialization and train ops. Used in
+/// V2 SavedModels.
+constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
+constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op";
+
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index c6abe2f41b9b5ec2faee6f65b429ff606f8ac08e..85d3dd01fa51b3c3ba6fcbf5faac03f1ff5630e2 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -122,34 +122,61 @@ Status RunOnce(const RunOptions& run_options,
return run_status;
}
-bool HasMainOp(const MetaGraphDef& meta_graph_def) {
+// RunInitOp will return OK if the initialization op was run successfully.
+// An empty init_op_name indicates that there are no init ops to run.
+Status RunInitOp(const RunOptions& run_options, const string& export_dir,
+ const MetaGraphDef& meta_graph_def,
+ const std::vector& asset_file_defs,
+ Session* session, const string& init_op_name) {
+ if (!init_op_name.empty()) {
+ LOG(INFO) << "Running initialization op on SavedModel bundle.";
+ std::vector> inputs;
+ AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
+ RunMetadata run_metadata;
+ return RunOnce(run_options, inputs, {}, {init_op_name},
+ nullptr /* outputs */, &run_metadata, session);
+ }
+ return Status::OK();
+}
+
+// Finds the name of the initialization op recorded in 'meta_graph_def' and
+// stores it in '*init_op_name'; the string is not modified when the
+// SavedModel has no init op. A SavedModel may store the name in the
+// SignatureDef (v2) or in a collection (v1). Returns FailedPrecondition if the
+// init op SignatureDef has no output named kSavedModelInitOpSignatureKey, or
+// if an init op collection exists but does not contain exactly one op.
+Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
+ string* init_op_name) {
+ const auto& sig_def_map = meta_graph_def.signature_def();
+ const auto init_op_sig_it = sig_def_map.find(kSavedModelInitOpSignatureKey);
+ if (init_op_sig_it != sig_def_map.end()) {
+ const auto& init_op_outputs = init_op_sig_it->second.outputs();
+ const auto output_it = init_op_outputs.find(kSavedModelInitOpSignatureKey);
+ // A malformed SignatureDef may lack the output; never dereference end().
+ if (output_it == init_op_outputs.end()) {
+ return errors::FailedPrecondition("Could not find output ",
+ kSavedModelInitOpSignatureKey);
+ }
+ *init_op_name = output_it->second.name();
+ return Status::OK();
+ }
+
const auto& collection_def_map = meta_graph_def.collection_def();
+ string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
- return true;
+ init_op_collection_key = kSavedModelMainOpKey;
+ } else {
+ init_op_collection_key = kSavedModelLegacyInitOpKey;
}
- return false;
-}
-Status RunMainOp(const RunOptions& run_options, const string& export_dir,
- const MetaGraphDef& meta_graph_def,
- const std::vector& asset_file_defs,
- Session* session, const string& main_op_key) {
- LOG(INFO) << "Running MainOp with key " << main_op_key
- << " on SavedModel bundle.";
- const auto& collection_def_map = meta_graph_def.collection_def();
- const auto main_op_it = collection_def_map.find(main_op_key);
- if (main_op_it != collection_def_map.end()) {
- if (main_op_it->second.node_list().value_size() != 1) {
+ const auto init_op_it = collection_def_map.find(init_op_collection_key);
+ if (init_op_it != collection_def_map.end()) {
+ if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
- std::vector> inputs;
- AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
- RunMetadata run_metadata;
- const StringPiece main_op_name = main_op_it->second.node_list().value(0);
- return RunOnce(run_options, inputs, {}, {string(main_op_name)},
- nullptr /* outputs */, &run_metadata, session);
+ *init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
@@ -193,6 +213,15 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector* asset_file_defs) {
+ // With SavedModel v2, we write asset file def into metagraph instead of
+ // collection, so read from metagraph first.
+ if (meta_graph_def.asset_file_def_size() > 0) {
+ for (const auto& asset : meta_graph_def.asset_file_def()) {
+ asset_file_defs->push_back(asset);
+ }
+ return Status::OK();
+ }
+ // Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
@@ -227,15 +256,12 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(),
asset_file_defs, bundle->session.get()));
- if (HasMainOp(bundle->meta_graph_def)) {
- TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir,
- bundle->meta_graph_def, asset_file_defs,
- bundle->session.get(), kSavedModelMainOpKey));
- } else {
- TF_RETURN_IF_ERROR(RunMainOp(
- run_options, export_dir, bundle->meta_graph_def, asset_file_defs,
- bundle->session.get(), kSavedModelLegacyInitOpKey));
- }
+ string init_op_name;
+ TF_RETURN_IF_ERROR(
+ GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
+ TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
+ asset_file_defs, bundle->session.get(),
+ init_op_name));
return Status::OK();
}
diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc
index 72b8bc18710b0ee77cb01ed3ad0c2abb5183efb2..597e42bb65ab5536664089f7e65ec52d77fc8f23 100644
--- a/tensorflow/cc/saved_model/loader_test.cc
+++ b/tensorflow/cc/saved_model/loader_test.cc
@@ -36,6 +36,8 @@ constexpr char kTestDataMainOp[] =
"cc/saved_model/testdata/half_plus_two_main_op/00000123";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123";
+constexpr char kTestDataInitOpV2[] =
+ "cc/saved_model/testdata/half_plus_two_v2/00000123";
class LoaderTest : public ::testing::Test {
protected:
@@ -227,5 +229,17 @@ TEST_F(LoaderTest, MaybeSavedModelDirectory) {
EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
}
+TEST_F(LoaderTest, SavedModelInitOpV2Format) {
+ SavedModelBundle bundle;
+ SessionOptions session_options;
+ RunOptions run_options;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2);
+ TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
+ {kSavedModelTagServe}, &bundle));
+ CheckSavedModelBundle(export_dir, bundle);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f9ff036688007836524129e23f5cf82edd1e8910
--- /dev/null
+++ b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt
@@ -0,0 +1 @@
+asset-file-contents
\ No newline at end of file
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb
new file mode 100644
index 0000000000000000000000000000000000000000..a10bbf8fb6bca0fcee6414b2927d2f706de85ebc
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb differ
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001
new file mode 100644
index 0000000000000000000000000000000000000000..15b75d6ef6bffc336d138d923badb3928b8c4c13
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 differ
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index
new file mode 100644
index 0000000000000000000000000000000000000000..7ec9fb4fe2dd21d0a6c324aecd7658fc37cf2326
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index differ
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index b17bc658fa06b9feb7edb292bd89ef31e6309169..ab1c1be344e2257721507543bc7647d4ff4becb2 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -164,7 +164,8 @@ string RewriteWithName(const string& name, string code,
}
// Generate methods for args (inputs).
-Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
+Status GenArgMethods(const tf2xla::Config& config,
+ const xla::ProgramShapeProto& ps,
const CompileResult& compile_result, string* methods) {
size_t num_args = ps.parameters_size();
if (config.feed_size() != num_args) {
@@ -174,9 +175,10 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
}
for (int i = 0; i < num_args; ++i) {
std::vector> rewrites;
- TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites));
+ TF_RETURN_IF_ERROR(
+ AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
const string code = R"(
- void set_arg{{NAME}}_data(void* data) {
+ void set_arg{{NAME}}_data(const void* data) {
set_arg_data({{I}}, data);
}
{{TYPE}}* arg{{NAME}}_data() {
@@ -204,7 +206,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
// Generate methods for results (outputs).
Status GenResultMethods(const tf2xla::Config& config,
- const xla::ProgramShape& ps, string* methods) {
+ const xla::ProgramShapeProto& ps, string* methods) {
if (ps.result().element_type() != xla::TUPLE) {
// The XlaCompiler we use to build the xla computation always generates a
// tuple result, and we rely on this to simplify code generation.
@@ -217,8 +219,8 @@ Status GenResultMethods(const tf2xla::Config& config,
}
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
std::vector> rewrites;
- TF_RETURN_IF_ERROR(
- AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites));
+ TF_RETURN_IF_ERROR(AddRewritesForShape(
+ i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
string code = R"(
{{TYPE}}* result{{NAME}}_data() {
return static_cast<{{TYPE}}*>(result_data({{I}}));
@@ -336,7 +338,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
ExtractEntryParamBufferInfos(buffer_infos);
std::vector buffer_infos_for_temps =
ExtractTempBufferInfos(buffer_infos);
- const xla::ProgramShape& ps = compile_result.program_shape;
+ const xla::ProgramShapeProto& ps = compile_result.program_shape;
string methods_arg, methods_result;
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
@@ -548,8 +550,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
// Shape of the args and results.
- static const xla::ProgramShape* StaticProgramShape() {
- static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
+ static const xla::ProgramShapeProto* StaticProgramShape() {
+ static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
return kShape;
}
@@ -587,7 +589,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{METHODS_RESULT}}\n", methods_result},
{"{{NS_END}}\n", ns_end},
{"{{NS_START}}\n", ns_start},
- {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
+ {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
metadata_result.program_shape_access_shim},
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
@@ -615,11 +617,11 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts,
Status GenerateMetadata(const CodegenOpts& opts,
const CompileResult& compile_result,
MetadataResult* metadata_result) {
- std::unique_ptr program_shape;
+ std::unique_ptr program_shape;
if (opts.gen_program_shape) {
program_shape =
- absl::make_unique(compile_result.program_shape);
+ absl::make_unique(compile_result.program_shape);
// The parameter names are currently meaningless, and redundant with the
// rest of our metadata, so clear them out to avoid confusion and save
@@ -631,8 +633,8 @@ Status GenerateMetadata(const CodegenOpts& opts,
// a shim that evaluates to nullptr, which is what we want.
ProtobufToEmbed program_shape_protobuf{
- CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape",
- program_shape.get()};
+ CreateUniqueIdentifier(opts, "ProgramShapeProto"),
+ "xla::ProgramShapeProto", program_shape.get()};
ProtobufToEmbed hlo_profile_printer_data_protobuf{
CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h
index 90410c46a8e36e44454f1219ad76d0fb0937070d..9485e86b10e225a3c9c12eafd9905bdf7c15c9fa 100644
--- a/tensorflow/compiler/aot/codegen.h
+++ b/tensorflow/compiler/aot/codegen.h
@@ -57,7 +57,7 @@ struct MetadataResult {
std::vector header_variable_decls;
// program_shape_access_shim is a C++ expression that constructs the
- // xla::ProgramShape instance for the CompileResult passed to
+ // xla::ProgramShapeProto instance for the CompileResult passed to
// GenerateMetadata.
string program_shape_access_shim;
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index bb288d23000527be74f01630d20bbf82e50007ce..c1788ca32a1d099284eeb870f9513891051fd29e 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -181,13 +181,15 @@ TEST(CodegenTest, Golden) {
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
5, {}));
- compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
- {
- xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
- xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
- },
- xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}));
+ compile_result.program_shape =
+ xla::ShapeUtil::MakeProgramShape(
+ {
+ xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
+ xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
+ },
+ xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}))
+ .ToProto();
compile_result.entry_point = "entry_point";
compile_result.pointer_size = 8;
diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden
index e4d8a02877c75fa72c5747650ab9c7ac229955b3..968afad65ed6d4b5510687df484b7ce6743f6a85 100644
--- a/tensorflow/compiler/aot/codegen_test_h.golden
+++ b/tensorflow/compiler/aot/codegen_test_h.golden
@@ -22,7 +22,7 @@ extern "C" void entry_point(
void* result, const xla::ExecutableRunOptions* run_options,
const void** args, void** temps, tensorflow::int64* profile_counters);
-extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[];
+extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[];
namespace foo {
@@ -114,7 +114,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
- void set_arg0_data(void* data) {
+ void set_arg0_data(const void* data) {
set_arg_data(0, data);
}
float* arg0_data() {
@@ -132,7 +132,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
arg_data(0)))[dim0][dim1];
}
- void set_arg_myfeed_data(void* data) {
+ void set_arg_myfeed_data(const void* data) {
set_arg_data(0, data);
}
float* arg_myfeed_data() {
@@ -150,7 +150,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
arg_data(0)))[dim0][dim1];
}
- void set_arg1_data(void* data) {
+ void set_arg1_data(const void* data) {
set_arg_data(1, data);
}
tensorflow::int64* arg1_data() {
@@ -253,10 +253,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
}
// Shape of the args and results.
- static const xla::ProgramShape* StaticProgramShape() {
- static const xla::ProgramShape* kShape = []() {
- xla::ProgramShape* proto = new xla::ProgramShape;
- proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52);
+ static const xla::ProgramShapeProto* StaticProgramShape() {
+ static const xla::ProgramShapeProto* kShape = []() {
+ xla::ProgramShapeProto* proto = new xla::ProgramShapeProto;
+ proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52);
return proto;
}();
return kShape;
diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden
index eb001c5d45bdfefc76629d7303d89f5480432235..ce8e5ec8c96a2c3696f14b8eea206d648182ecb5 100644
Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 2b5f97b34cd928d32eb220536342c715d91d45bb..9fc223bdc7c0e207ce2005cb86250aa77e709df8 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -56,17 +56,25 @@ Status CompileXla(xla::CompileOnlyClient* client,
return errors::Unknown("Couldn't get XLA program shape: ",
pshape_or.status().error_message());
}
- compile_result->program_shape = *pshape_or.ValueOrDie();
- xla::ProgramShape* pshape = &compile_result->program_shape;
- std::vector arg_layouts;
- arg_layouts.reserve(pshape->parameters_size());
+ compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
+ xla::ProgramShapeProto* pshape = &compile_result->program_shape;
+
+ // AotXlaComputationInstance::argument_layouts is a vector of Shape
+ // pointers. Accumulate the Shape objects themselves in a separate vector
+ // while building the vector of pointers.
+ std::vector arg_layout_ptrs(pshape->parameters_size());
+ std::vector arg_layouts(pshape->parameters_size());
for (int i = 0; i < pshape->parameters_size(); ++i) {
- arg_layouts.push_back(pshape->mutable_parameters(i));
+ arg_layouts[i] = xla::Shape(pshape->parameters(i));
+ arg_layout_ptrs[i] = &arg_layouts[i];
}
xla::CompileOnlyClient::AotXlaComputationInstance instance;
instance.computation = &computation;
- instance.argument_layouts = std::move(arg_layouts);
- instance.result_layout = &pshape->result();
+ instance.argument_layouts = std::move(arg_layout_ptrs);
+ // result_layout is a non-owning pointer, so result_shape must stay alive
+ // until CompileAheadOfTime below has consumed 'instance'.
+ xla::Shape result_shape(pshape->result());
+ instance.result_layout = &result_shape;
xla::StatusOr>>
aot_or = client->CompileAheadOfTime({instance}, aot_opts);
if (!aot_or.ok()) {
diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h
index e03c5b1aa77c1262ed903aae3072ef65f34d80a2..ee7bb26fabd2d897b85b62f38778ecbfe2238eb6 100644
--- a/tensorflow/compiler/aot/compile.h
+++ b/tensorflow/compiler/aot/compile.h
@@ -33,9 +33,9 @@ namespace tfcompile {
struct CompileResult {
// Contains object file and meta-info.
std::unique_ptr aot;
- xla::ProgramShape program_shape; // Static shape of args and results.
- string entry_point; // Name of generated function.
- int pointer_size = 0; // Size of a pointer in bytes.
+ xla::ProgramShapeProto program_shape; // Static shape of args and results.
+ string entry_point; // Name of generated function.
+ int pointer_size = 0; // Size of a pointer in bytes.
};
// CompileGraph compiles the graph_def into an object file containing a function
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index f10852c7850f61bfd8b99fa9f1648202d182085e..4dd79e5882d7da61be029735ef2b165908c599f9 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -526,13 +526,15 @@ TEST(TFCompileTest, ProgramShape) {
// muladd has the program shape defined.
MatMulAndAddComp muladd;
- const xla::ProgramShape* muladd_shape = muladd.ProgramShape();
+ const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
ASSERT_TRUE(muladd_shape != nullptr);
ASSERT_EQ(muladd_shape->parameters_size(), 2);
- EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));
- EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2));
+ EXPECT_TRUE(
+ ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(0)), f32_2x2));
+ EXPECT_TRUE(
+ ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(1)), f32_2x2));
- const xla::Shape& muladd_result = muladd_shape->result();
+ const xla::Shape muladd_result(muladd_shape->result());
ASSERT_EQ(muladd_result.element_type(), xla::TUPLE);
ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2);
const xla::Shape& muladd_result0 =
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 162a137fa7a5573056911d19472de4261574137a..15dcbb2641eca031e82db9aa58dee6a14ab0a2cc 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -23,7 +23,6 @@ package(
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -38,7 +37,7 @@ cc_library(
":xla_cpu_device",
":xla_cpu_jit",
"//tensorflow/compiler/plugin",
- ] + if_cuda_is_configured([
+ ] + if_cuda([
":xla_gpu_device",
":xla_gpu_jit",
]),
@@ -51,6 +50,7 @@ cc_library(
deps = [
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",
@@ -76,10 +76,11 @@ cc_library(
srcs = ["xla_cpu_device.cc"],
visibility = [":friends"],
deps = [
+ ":create_xla_launch_op", # buildcleaner: keep
+ ":flags",
":jit_compilation_passes",
":xla_device",
"//tensorflow/compiler/jit/kernels:xla_ops",
- "//tensorflow/compiler/jit/legacy_flags:xla_device_flags",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
@@ -95,6 +96,7 @@ cc_library(
srcs = ["xla_gpu_device.cc"],
visibility = [":friends"],
deps = [
+ ":create_xla_launch_op", # buildcleaner: keep
":jit_compilation_passes",
":xla_device",
"//tensorflow/compiler/jit/kernels:xla_ops",
@@ -104,6 +106,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
alwayslink = 1,
)
@@ -210,6 +213,18 @@ cc_library(
# Internal targets below this point.
+cc_library(
+ name = "flags",
+ srcs = ["flags.cc"],
+ hdrs = ["flags.h"],
+ visibility = [":friends"],
+ deps = [
+ "//tensorflow/compiler/xla:parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
cc_library(
name = "common",
srcs = [
@@ -256,6 +271,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
@@ -268,6 +284,7 @@ cc_library(
"//tensorflow/core/kernels:variable_ops",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
@@ -487,6 +504,7 @@ cc_library(
deps = [
":common",
":encapsulate_util",
+ ":flags",
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
@@ -494,8 +512,6 @@ cc_library(
"//tensorflow/cc:ops",
"//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles",
- "//tensorflow/compiler/jit/legacy_flags:build_xla_ops_pass_flags",
- "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table",
@@ -544,25 +560,6 @@ cc_library(
hdrs = ["union_find.h"],
)
-cc_library(
- name = "producer_consumer_queue",
- hdrs = ["producer_consumer_queue.h"],
- deps = ["//tensorflow/core:lib"],
-)
-
-tf_cc_test(
- name = "producer_consumer_queue_test",
- size = "small",
- srcs = ["producer_consumer_queue_test.cc"],
- deps = [
- ":producer_consumer_queue",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
- ],
-)
-
tf_cc_test(
name = "deadness_analysis_test",
size = "small",
@@ -743,7 +740,10 @@ tf_custom_op_py_library(
visibility = [
":friends",
],
- deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"],
+ deps = [
+ "//tensorflow/compiler/jit/ops:xla_ops_grad",
+ "//tensorflow/compiler/jit/ops:xla_ops_wrapper_py",
+ ],
)
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index 93637a69d5d7b6bf9e9ce784ae521ef0e9b121b9..9f4042630edaec1b9519b6434d859a48372e8b15 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/cc/ops/control_flow_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h"
+#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
@@ -320,10 +320,10 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
return IsXlaCompiledKernel(*n);
});
- bool lazy_compilation_enabled = enable_lazy_compilation_
- ? *enable_lazy_compilation_
- : legacy_flags::GetBuildXlaOpsPassFlags()
- .tf_xla_enable_lazy_compilation;
+ bool lazy_compilation_enabled =
+ enable_lazy_compilation_
+ ? *enable_lazy_compilation_
+ : GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation;
for (Node* n : xla_compiled_kernels) {
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
index 11df946cc186660242574c2644463a26ead44f1f..48a23a4c1711ac88a329723c46559112d5a39dbd 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -42,14 +42,8 @@ class BuildXlaOpsTest : public ::testing::Test {
.ok());
}
- void TearDown() override {
- for (Device* device : devices_) {
- delete device;
- }
- }
-
private:
- std::vector devices_;
+ std::vector> devices_;
};
using ::tensorflow::testing::FindNodeByName;
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
index 73866607621cd745f6e640a14405daebf0dd9985..0f872a480f4d4843217f1df3452c4dc62531264e 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
@@ -59,8 +59,9 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 1});
+ std::vector> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
- options, "/job:localhost/replica:0/task:0", &devices_));
+ options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto;
for (const auto& fdef : flib) {
@@ -69,7 +70,7 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
lib_def_ = absl::make_unique(
OpRegistry::Global(), proto);
OptimizerOptions opts;
- device_mgr_ = absl::make_unique(devices_);
+ device_mgr_ = absl::make_unique(std::move(devices));
pflr_ = absl::make_unique(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
@@ -77,7 +78,6 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
}
FunctionLibraryRuntime* flr_;
- std::vector devices_;
std::unique_ptr device_mgr_;
std::unique_ptr lib_def_;
std::unique_ptr pflr_;
diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc
index 28ec37b1b9c8a1a306b5e778bac5b6ba01c2c997..1f4b9c90a4ff0b1166cdb7b5942771b350740ef3 100644
--- a/tensorflow/compiler/jit/encapsulate_util.cc
+++ b/tensorflow/compiler/jit/encapsulate_util.cc
@@ -86,7 +86,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
continue;
} else if (src_xla_computation && !dst_xla_computation) {
if (src_outside_compilation) {
- // Case 1d: outside compilation to host computation control edge.
+ // Case 1c: outside compilation to host computation control edge.
edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr(
@@ -94,7 +94,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
}
} else if (!src_xla_computation && dst_xla_computation) {
if (dst_outside_compilation) {
- // Case 1d: host computation control to outside compilation edge.
+ // Case 1c: host computation control to outside compilation edge.
edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr(
@@ -103,40 +103,24 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
} else { // src_xla_computation && dst_xla_computation
if (*src_xla_computation != *dst_xla_computation) {
if (src_outside_compilation && dst_outside_compilation) {
- // Case 1c: outside compilation to outside compilation control edge.
+ // Case 1b: outside compilation to outside compilation control edge.
edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr(
e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
} else if (src_outside_compilation && !dst_outside_compilation) {
- // Case 1b: outside compilation to another XLA computaition control
+ // Case 1a: outside compilation to another XLA computation control
// edge.
TF_RETURN_IF_ERROR(AppendToListAttr(
e->src(), kXlaConnectedToOtherXlaComputationAttrName,
*dst_xla_computation));
} else if (!src_outside_compilation && dst_outside_compilation) {
- // Case 1b: another XLA computaition to outside compilation control
+ // Case 1a: another XLA computation to outside compilation control
// edge.
TF_RETURN_IF_ERROR(AppendToListAttr(
e->dst(), kXlaConnectedFromOtherXlaComputationAttrName,
*src_xla_computation));
}
- } else { // *src_xla_computation == *dst_xla_computation
- if (src_outside_compilation && dst_outside_compilation) {
- if (*src_outside_compilation != *dst_outside_compilation) {
- // Case 1c: outside compilation to outside compilation control edge.
- edges_to_remove.push_back(e);
-
- TF_RETURN_IF_ERROR(AppendToListAttr(
- e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
- }
- } else if (src_outside_compilation && !dst_outside_compilation) {
- // Case 1a: outside compilation to its XLA computation control edge.
- ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
- } else if (!src_outside_compilation && dst_outside_compilation) {
- // Case 1a: XLA computation to outside compilation in it control edge.
- ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
- }
}
}
}
@@ -181,12 +165,6 @@ Status ProcessXlaToXlaDataEdges(Graph* g,
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
}
- } else { // *src_xla_computation == *dst_xla_computation
- if (src_outside_compilation && dst_outside_compilation &&
- *src_outside_compilation != *dst_outside_compilation) {
- edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
- VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
- }
}
}
@@ -263,7 +241,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
// Remove the edge from host to outside compilation. Add a placeholder as
// outside compilation node input.
- std::map placeholders;
+ std::map, Node*> placeholders;
for (int i = 0; i < edges.size(); i++) {
Node* dst = g->FindNodeId(edges[i].dst_node_id);
const Edge* e;
@@ -275,9 +253,10 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
// Find or create placeholder node.
string new_name =
edges[i].is_host_to_outside_compilation
- ? absl::StrCat(src->name(), "_host_to_oc_placeholder")
- : absl::StrCat(src->name(), "_oc_to_host_placeholder");
- auto iter = placeholders.find(new_name);
+ ? absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output)
+ : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output);
+ auto placeholder_index = std::make_pair(src->name(), src_output);
+ auto iter = placeholders.find(placeholder_index);
Node* placeholder_node;
if (iter == placeholders.end()) {
NodeDefBuilder placeholder_builder(new_name, "Placeholder");
@@ -310,7 +289,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
Status s;
placeholder_node = g->AddNode(placeholder_def, &s);
TF_RETURN_IF_ERROR(s);
- placeholders[new_name] = placeholder_node;
+ placeholders[placeholder_index] = placeholder_node;
} else {
placeholder_node = iter->second;
}
@@ -594,14 +573,244 @@ Status AddControlDependencies(
return Status::OK();
}
+// Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PreprocessEdgesBetweenOutsideCompilations` for details.
+Status PreprocessControlEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Gather edges to remove. We should not remove the edge while iterating.
+ std::vector edges_to_remove;
+ for (const Edge* e : g->edges()) {
+ if (!e->IsControlEdge()) {
+ continue;
+ }
+
+ auto src_outside_compilation =
+ GetStringAttr(*e->src(), outside_compilation_attr_name);
+ auto dst_outside_compilation =
+ GetStringAttr(*e->dst(), outside_compilation_attr_name);
+
+ if (src_outside_compilation && dst_outside_compilation) {
+ if (*src_outside_compilation != *dst_outside_compilation) {
+ // Case 1a: outside compilation to outside compilation control edge.
+ edges_to_remove.push_back(e);
+
+ TF_RETURN_IF_ERROR(AppendToListAttr(
+ e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName,
+ e->src()->name()));
+ }
+ } else if (src_outside_compilation && !dst_outside_compilation) {
+ // Case 1b: outside compilation to its XLA computation control edge.
+ ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
+ } else if (!src_outside_compilation && dst_outside_compilation) {
+ // Case 1b: XLA computation to its outside compilation control edge.
+ ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
+ }
+ }
+
+ for (auto e : edges_to_remove) {
+ g->RemoveEdge(e);
+ }
+ return Status::OK();
+}
+
+// Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PreprocessEdgesBetweenOutsideCompilations` for details.
+Status PreprocessDataEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Gather data edges between different outside compilations. Notice that we
+ // do not store `Edge*` directly because we replace some nodes while adding
+ // Placeholder nodes, and those Edge pointers might be invalidated.
+ struct EdgeInfo {
+ int dst_input, dst_node_id;
+ };
+ std::vector edges;
+ for (const Edge* e : g->edges()) {
+ if (e->IsControlEdge()) {
+ continue;
+ }
+
+ auto src_outside_compilation =
+ GetStringAttr(*e->src(), outside_compilation_attr_name);
+ auto dst_outside_compilation =
+ GetStringAttr(*e->dst(), outside_compilation_attr_name);
+
+ if (src_outside_compilation && dst_outside_compilation &&
+ *src_outside_compilation != *dst_outside_compilation) {
+ edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
+ VLOG(4) << "Oc -> oc edge: " << e->DebugString();
+ }
+ }
+
+ // Remove the edge between outside compilations. Add a placeholder as the
+ // destination outside compilation node's input.
+ std::map, Node*> placeholders;
+ for (int i = 0; i < edges.size(); i++) {
+ Node* dst = g->FindNodeId(edges[i].dst_node_id);
+ const Edge* e;
+ TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
+ Node* src = e->src();
+ int src_output = e->src_output(), dst_input = e->dst_input();
+ g->RemoveEdge(e);
+
+ // Find or create placeholder node.
+ string new_name =
+ absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
+ auto placeholder_index = std::make_pair(src->name(), src_output);
+ auto iter = placeholders.find(placeholder_index);
+ Node* placeholder_node;
+ if (iter == placeholders.end()) {
+ NodeDefBuilder placeholder_builder(new_name, "Placeholder");
+ placeholder_builder.Attr("dtype", src->output_type(src_output));
+ string outside_compilation_attr;
+ TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
+ outside_compilation_attr_name,
+ &outside_compilation_attr));
+ placeholder_builder.Attr(outside_compilation_attr_name,
+ outside_compilation_attr);
+ placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName,
+ src->name());
+ placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName,
+ src_output);
+ NodeDef placeholder_def;
+ TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
+ Status s;
+ placeholder_node = g->AddNode(placeholder_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ placeholders[placeholder_index] = placeholder_node;
+ } else {
+ placeholder_node = iter->second;
+ }
+ g->AddEdge(placeholder_node, 0, dst, dst_input);
+
+ // Replace `e->dst()` because its input node changed.
+ NodeDef new_def = dst->def();
+ *new_def.mutable_input(dst_input) = placeholder_node->name();
+ TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
+
+ // Other entries in `edges` might have `dst` as their dst node. Since
+ // `ReplaceNode` returned `dst_replace_node` for `dst`, update those
+ // entries to refer to `dst_replace_node`.
+ for (int j = i + 1; j < edges.size(); j++) {
+ if (edges[j].dst_node_id == edges[i].dst_node_id) {
+ edges[j].dst_node_id = dst_replace_node->id();
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PostprocessEdgesBetweenOutsideCompilations` for details.
+Status PostprocessDataEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Gather all outside compilation to outside compilation placeholder nodes.
+ std::vector placeholder_nodes;
+ for (Node* n : g->nodes()) {
+ if (n->type_string() == "Placeholder" &&
+ HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) {
+ placeholder_nodes.push_back(n);
+ }
+ }
+
+ // Remove the placeholder nodes, and reconnect original edge.
+ auto node_name_index = g->BuildNodeNameIndex();
+ for (auto n : placeholder_nodes) {
+ string node_name;
+ int node_src_output;
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name));
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output));
+ auto iter = node_name_index.find(node_name);
+ if (iter == node_name_index.end()) {
+ return errors::Internal(
+ "Cannot find original node for oc -> oc placeholder node ",
+ node_name);
+ }
+
+ // Rewire every consumer of the placeholder to the original node instead.
+ Node* original_node = iter->second;
+ std::vector control_edges;
+ std::vector data_edges;
+ for (auto e : n->out_edges()) {
+ if (e->IsControlEdge()) {
+ control_edges.push_back(e);
+ } else {
+ data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
+ }
+ }
+ for (const Edge* e : control_edges) {
+ g->AddControlEdge(original_node, e->dst());
+ g->RemoveEdge(e);
+ }
+ for (int i = 0; i < data_edges.size(); i++) {
+ Node* dst = data_edges[i].dst;
+ NodeDef new_def = dst->def();
+ int dst_input = data_edges[i].dst_input;
+ *new_def.mutable_input(dst_input) =
+ absl::StrCat(original_node->name(), ":", node_src_output);
+ TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
+
+ const Edge* edge_to_replace = nullptr;
+ TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
+ g->RemoveEdge(edge_to_replace);
+ g->AddEdge(original_node, node_src_output, replace_node, dst_input);
+
+ // Other edges might have `dst` as dst node. Update those edges with
+ // `replace_node`.
+ for (int j = i + 1; j < data_edges.size(); j++) {
+ if (data_edges[j].dst == dst) {
+ data_edges[j].dst = replace_node;
+ }
+ }
+
+ // Other placeholder node might have `dst` as original node. Update
+ // `node_name_index` with `replace_node`.
+ node_name_index[replace_node->name()] = replace_node;
+ }
+
+ // Remove placeholder node.
+ g->RemoveNode(n);
+ }
+ return Status::OK();
+}
+
+// Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PostprocessEdgesBetweenOutsideCompilations` for details.
+Status PostprocessControlEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ auto node_name_index = g->BuildNodeNameIndex();
+
+ // Reconnect outside compilation to outside compilation control edge.
+ for (Node* n : g->nodes()) {
+ std::vector control_deps;
+ Status s =
+ GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName,
+ &control_deps);
+ if (!s.ok()) {
+ if (s.code() != error::NOT_FOUND) {
+ return s;
+ } else {
+ continue;
+ }
+ } else {
+ n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName);
+ for (const string& control_input : control_deps) {
+ auto iter = node_name_index.find(control_input);
+ if (iter == node_name_index.end()) {
+ return errors::Internal("Cannot find original node for ",
+ control_input);
+ }
+ g->AddControlEdge(iter->second, n);
+ }
+ }
+ }
+ return Status::OK();
+}
} // namespace
const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";
-const char kXlaConnectedToXlaComputationAttrName[] =
- "_xla_connected_to_xla_computation";
-const char kXlaConnectedFromXlaComputationAttrName[] =
- "_xla_connected_from_xla_computation";
const char kXlaConnectedToOtherXlaComputationAttrName[] =
"_xla_connected_to_other_xla_computation";
const char kXlaConnectedFromOtherXlaComputationAttrName[] =
@@ -616,6 +825,15 @@ const char kHostToOutsideCompilationOriginalNodeAttrName[] =
"_xla_host_to_oc_node_name";
const char kHostToOutsideCompilationSrcOutputAttrName[] =
"_xla_host_to_oc_src_output";
+const char kXlaConnectedToXlaComputationAttrName[] =
+ "_xla_connected_to_xla_computation";
+const char kXlaConnectedFromXlaComputationAttrName[] =
+ "_xla_connected_from_xla_computation";
+const char kOutsideCompilationOriginalNodeAttrName[] =
+ "_xla_oc_to_oc_node_name";
+const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output";
+const char kXlaControlDependenciesWithinXlaClusterAttrName[] =
+ "_xla_control_dependencies_within_xla_cluster";
Status PerformStaticShapeInferenceBeforeEncapsulation(
Graph* g, const string& xla_computation_attr_name,
@@ -699,4 +917,39 @@ Status PostprocessForEncapsulation(
return Status::OK();
}
+Status PreprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Remove edges from source node to outside compilation nodes, and edges
+ // from outside compilation nodes to sink node.
+ std::vector edges_to_remove;
+ for (const Edge* e : g->source_node()->out_edges()) {
+ if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
+ edges_to_remove.push_back(e);
+ }
+ }
+ for (const Edge* e : g->sink_node()->in_edges()) {
+ if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) {
+ edges_to_remove.push_back(e);
+ }
+ }
+ for (auto e : edges_to_remove) {
+ g->RemoveEdge(e);
+ }
+
+ TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ return Status::OK();
+}
+
+Status PostprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h
index 5e0c4bf6a0cc92d69209595e257989665404db6b..e363bc5754ac395bae262dc67a780a0173efaf5e 100644
--- a/tensorflow/compiler/jit/encapsulate_util.h
+++ b/tensorflow/compiler/jit/encapsulate_util.h
@@ -44,14 +44,6 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(
Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name);
-// Attribute indicating that some ops in this node's XLA computation has control
-// dependency on this node. Attribute value will always be "true".
-extern const char kXlaConnectedToXlaComputationAttrName[];
-
-// Attribute indicating that this node has control dependency on some ops in
-// this node's XLA computation. Attribute value will always be "true".
-extern const char kXlaConnectedFromXlaComputationAttrName[];
-
// Attribute indicating that some ops in other XLA computation has control
// dependency on this node. Attribute value will be a list of string (XLA
// computation names).
@@ -81,6 +73,14 @@ extern const char kOutsideCompilationToHostOriginalNodeAttrName[];
// int (src_output for original edge).
extern const char kOutsideCompilationToHostSrcOutputAttrName[];
+// Attribute indicating that some ops in this node's XLA computation has control
+// dependency on this node. Attribute value will always be "true".
+extern const char kXlaConnectedToXlaComputationAttrName[];
+
+// Attribute indicating that this node has control dependency on some ops in
+// this node's XLA computation. Attribute value will always be "true".
+extern const char kXlaConnectedFromXlaComputationAttrName[];
+
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for an host node. Attribute value will be string
// (original input node name).
@@ -91,19 +91,31 @@ extern const char kHostToOutsideCompilationOriginalNodeAttrName[];
// for original edge).
extern const char kHostToOutsideCompilationSrcOutputAttrName[];
-// Preprocesses the graph for encapsulation. It will perform the following
-// operations in order:
+// Attribute indicating that this is a Placeholder node added to act as a
+// temporary input node for an outside compilation node. Attribute value will be
+// string (original input node name).
+extern const char kOutsideCompilationOriginalNodeAttrName[];
+
+// Attribute indicating that this is a Placeholder node added to act as a
+// temporary input node for an outside compilation node. Attribute value will be
+// int (src_output for original edge).
+extern const char kOutsideCompilationSrcOutputAttrName[];
+
+// Attribute indicating that this node has control dependencies on some other
+// nodes within the same XLA cluster. Attribute value will be a list of string
+// (node names).
+extern const char kXlaControlDependenciesWithinXlaClusterAttrName[];
+
+// Preprocesses edges between different XLA clusters for encapsulation. It will
+// perform the following operations in order:
//
-// 1a. For control edges between outside compilation and its XLA computation,
-// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
-// outside compilation node.
-// 1b. For control edges between outside compilation and another XLA
+// 1a. For control edges between outside compilation and another XLA
// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName
// = XLA computation node name" to the outside compilation node.
-// 1c. For control edges between different outside compilations, remove the edge
-// and add attr "kXlaControlDependenciesAttrName = src node name" to dst
-// node.
-// 1d. For control edges between outside compilation and host computation,
+// 1b. For control edges between different outside compilations (in different
+// XLA computations), remove the edge and add attr
+// "kXlaControlDependenciesAttrName = src node name" to dst node.
+// 1c. For control edges between outside compilation and host computation,
// remove the edge and add attr "kXlaControlDependenciesAttrName = src node
// name" to dst node.
// 2. For data edges between different XLA computations, if either src or dst
@@ -146,26 +158,53 @@ struct XlaClusterInfo {
const std::map host_compute_core;
};
-// Postprocesses the graph for encapsulation. This function reverts what
-// `PreprocessForEncapsulation` did. It will perform the following operations in
-// order:
+// Postprocesses edges between different XLA clusters for encapsulation. This
+// function reverts what `PreprocessForEncapsulation` did. It will perform the
+// following operations in order:
//
// 1. Remove Placeholder nodes between outside compilation and host computation
// (created in `PreprocessForEncapsulation` step 3).
// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2.
-// 3a. Reconnect control edges between different outside compilations (marked by
-// `PreprocessForEncapsulation` step 1c) and control edges between outside
-// compilation and host computation (marked by `PreprocessForEncapsulation`
-// step 1d).
-// 3b. Reconnect control edges between outside compilation and another XLA
-// computation (marked by `PreprocessForEncapsulation` step 1b).
-// Notice that control edges marked by `PreprocessForEncapsulation` step 1a are
-// not handled here. They are handled in `RewriteOutsideCompilationSubgraphFn`.
+// 3a. Reconnect control edges between outside compilation and another XLA
+// computation (marked by `PreprocessForEncapsulation` step 1a).
+// 3b. Reconnect control edges between different outside compilations (marked by
+// `PreprocessForEncapsulation` step 1b).
+// 3c. Reconnect control edges between outside compilation and host computation
+// (marked by `PreprocessForEncapsulation` step 1c).
Status PostprocessForEncapsulation(
Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map& clusters);
+// Preprocesses edges within the same XLA cluster. It will perform the following
+// operations in order:
+//
+// 0. Remove edges from source node to outside compilation nodes, and edges
+// from outside compilation nodes to sink node.
+// 1a. For control edges between different outside compilation clusters, remove
+// the edge and add attr "kXlaControlDependenciesWithinXlaClusterAttrName =
+// src node name" to dst node.
+// 1b. For control edges between outside compilation and its XLA computation,
+// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
+// outside compilation node.
+// 2. For data edges between different outside compilations, remove the edge
+// and create a Placeholder node as dst node's input.
+Status PreprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name);
+
+// Postprocesses edges within the same XLA cluster. This function reverts what
+// `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the
+// following operations in order:
+//
+// 1. Remove Placeholder nodes between different outside compilations (created
+// in `PreprocessEdgesBetweenOutsideCompilations` step 2).
+// 2a. Reconnect control edges between different outside compilations (marked by
+// `PreprocessEdgesBetweenOutsideCompilations` step 1a).
+// Notice that control edges marked by
+// `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here.
+// They are handled in `RewriteOutsideCompilationSubgraphFn`.
+Status PostprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc
index 7255df3112916b7abcc98ff8204efc8c02209b13..3b8b49cb92f3e453883a8e64e12ce3748a5173f6 100644
--- a/tensorflow/compiler/jit/encapsulate_util_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_util_test.cc
@@ -107,28 +107,19 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
identity4_node->AddAttr("_xla", "1");
identity4_node->AddAttr("_oc", "0");
identity5_node->AddAttr("_xla", "1");
- // Case 1a: control edges between outside compilation and its XLA computation.
- g.AddControlEdge(add_node, identity0_node);
- g.AddControlEdge(identity0_node, identity1_node);
- // Case 1b: control edges between outside compilation and another XLA
+ // Case 1a: control edges between outside compilation and another XLA
// computation.
g.AddControlEdge(identity0_node, identity3_node);
g.AddControlEdge(identity1_node, identity4_node);
- // Case 1c: control edges between different outside compilations.
+ // Case 1b: control edges between different outside compilations.
g.AddControlEdge(identity0_node, identity4_node);
- // Case 1d: control edges between outside compilation and host computation.
+ // Case 1c: control edges between outside compilation and host computation.
g.AddControlEdge(const0_node, identity0_node);
g.AddControlEdge(identity0_node, identity2_node);
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
- // Case 1a: add attr "_xla_connected_{from/to}_xla_computation = true" to the
- // outside compilation node.
- EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
- kXlaConnectedFromXlaComputationAttrName));
- EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
- kXlaConnectedToXlaComputationAttrName));
- // Case 1b: add attr "_xla_control_deps_{from/to} = XLA computation node name"
+ // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name"
// to the outside compilation node.
std::vector attr;
TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
@@ -140,13 +131,13 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
kXlaConnectedFromOtherXlaComputationAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "0");
- // Case 1c: add attr "_xla_control_deps = src node name" to dst node.
+ // Case 1b: add attr "_xla_control_deps = src node name" to dst node.
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
kXlaControlDependenciesAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "identity0");
- // Case 1d: add attr "_xla_control_deps = src node name" to dst node.
+ // Case 1c: add attr "_xla_control_deps = src node name" to dst node.
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
kXlaControlDependenciesAttrName, &attr));
@@ -162,23 +153,33 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
TEST(PreprocessForEncapsulationTest, DataEdges) {
// Build the graph:
// "const_0" and "const_1" in host computation
+ // "identityn0" = ("const_0", "const_1") in host computation
// "add0" = "const_0" + "const_1" in XLA computation 0
// "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
// "identity0" = "add1" in XLA computation 0
// "add2" = "add1" + "identity0" in host computation
// "add3" = "add1" + "add2" in XLA computation 1
- // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1
+ // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0
+ // "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 &
+ // outside compilation 0
+ // "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 &
+ // outside compilation 0
// "identity1" = "add4" in XLA computation 1
// "identity2" = "identity1" in host computation
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
+ auto identityn0 =
+ ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1});
Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
+ Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]);
+ auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"),
+ {identityn0[0], identityn0[1]});
Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
Graph g(OpRegistry::Global());
@@ -189,6 +190,8 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
*identity0_node = node_index["identity0"],
*add3_node = node_index["add3"], *add4_node = node_index["add4"],
+ *add5_node = node_index["add5"],
+ *identityn1_node = node_index["identityn_1"],
*identity1_node = node_index["identity1"];
add0_node->AddAttr("_xla", "0");
add1_node->AddAttr("_xla", "0");
@@ -197,6 +200,10 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
add3_node->AddAttr("_xla", "1");
add4_node->AddAttr("_xla", "1");
add4_node->AddAttr("_oc", "0");
+ add5_node->AddAttr("_xla", "1");
+ add5_node->AddAttr("_oc", "0");
+ identityn1_node->AddAttr("_xla", "1");
+ identityn1_node->AddAttr("_oc", "0");
identity1_node->AddAttr("_xla", "1");
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
@@ -214,8 +221,9 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
EXPECT_NE(bridge_identity0_add4, nullptr);
// Step 3: add placeholder for edges between host computation and outside
// compilation.
- EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder");
- Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"];
+ EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0");
+ Node *add1_oc_to_host_placeholder =
+ node_index["add1_oc_to_host_placeholder_0"];
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
kOutsideCompilationToHostOriginalNodeAttrName, &str));
EXPECT_EQ(str, "add1");
@@ -226,15 +234,34 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
add4_node = node_index["add4"];
ASSERT_NE(add4_node, nullptr);
EXPECT_EQ(add4_node->def().input(0),
- "bridge_identity0_add4_host_to_oc_placeholder");
+ "bridge_identity0_add4_host_to_oc_placeholder_0");
Node *identity0_host_to_oc_placeholder =
- node_index["bridge_identity0_add4_host_to_oc_placeholder"];
+ node_index["bridge_identity0_add4_host_to_oc_placeholder_0"];
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
kHostToOutsideCompilationOriginalNodeAttrName, &str));
EXPECT_EQ(str, "bridge_identity0_add4");
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
kHostToOutsideCompilationSrcOutputAttrName, &i));
EXPECT_EQ(i, 0);
+
+ // Check different placeholder nodes are created for different src_output.
+ Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"],
+ *placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"];
+ EXPECT_NE(placeholder0, nullptr);
+ EXPECT_NE(placeholder1, nullptr);
+ // Check we only have 2 placeholder nodes created for "identityn_0".
+ int placeholder_count = 0;
+ for (Node *n : g.nodes()) {
+ if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) {
+ string attr;
+ TF_CHECK_OK(GetNodeAttr(
+ n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr));
+ if (attr == "identityn_0") {
+ ++placeholder_count;
+ }
+ }
+ }
+ EXPECT_EQ(placeholder_count, 2);
}
TEST(PostprocessForEncapsulationTest, ControlEdges) {
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index 2ce6fa73fc448ca83fa392aa909cb385453eb8b6..d334100aa4a915a87fb05d371e0e3379a7ee05f2 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -195,8 +195,11 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors,
e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
e->dst()->type_string() != kXlaClusterOutput) {
return errors::InvalidArgument(
- "Undeclared output of XLA computation. A common cause of this error "
- "is variable initializers that depend on the XLA computation. Edge: ",
+ "Undeclared output of XLA computation. Some common causes of this "
+ "error are: 1) variable initializers that depend on the XLA "
+ "computation; 2) gradient computations that depend on the XLA "
+ "computation, which can be mitigated by moving gradient computations "
+ "inside XLA computation. Offending edge: ",
e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
e->dst_input());
}
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
index 8b3587c5087a0651c466f53f3709ba21e75dd273..e3c7e2f89be9b37b51a633dabb099969c181013f 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
@@ -366,7 +366,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
// replace this node with compilation result node.
// 3) all outside compilation graphs.
Status ConstructHostGraph(
- const string& xla_cluster_name,
+ const string& xla_cluster_name, const string& outside_compilation_attr_name,
const std::vector& outside_compilation_host_graphs,
FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) {
host_graph->reset(new Graph(fld));
@@ -476,6 +476,10 @@ Status ConstructHostGraph(
host_graph->get(),
std::unordered_set{(*host_graph)->sink_node()});
+ // Postprocess edges between different outside compilations.
+ TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
+ host_graph->get(), outside_compilation_attr_name));
+
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("extract_outside_compilation_host_graph_for_",
@@ -801,6 +805,11 @@ Status ExtractOutsideCompilationForFunction(
},
&fbody));
std::unique_ptr fbody_deleter(fbody);
+
+ // Preprocess edges between different outside compilations. They will be
+ // restored in `ConstructHostGraph()`.
+ TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
+ fbody->graph, outside_compilation_attr_name));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
@@ -860,8 +869,9 @@ Status ExtractOutsideCompilationForFunction(
// Construct host graph.
if (!outside_compilation_host_graphs.empty()) {
- TF_RETURN_IF_ERROR(ConstructHostGraph(
- xla_cluster_name, outside_compilation_host_graphs, fld, host_graph));
+ TF_RETURN_IF_ERROR(
+ ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
+ outside_compilation_host_graphs, fld, host_graph));
}
// Remove the outside compilation graphs from function library.
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
index c5bd64f004ef98853955372680277e04c16bdc9e..bff956100da661b679b4557fce53671e6cef88c5 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
@@ -290,21 +290,18 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) {
TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shapes", &shapes));
EXPECT_EQ(shapes.size(), 1);
EXPECT_EQ(shapes[0].dim_size(), 1);
- // Check XlaHostCompute nodes' "shape_inference_graph" attr. "0" should have a
- // non-empty value, and "1" should have an empty value.
+ // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have
+ // empty values.
string shape_inference_graph;
TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph",
&shape_inference_graph));
- EXPECT_EQ(shape_inference_graph,
- "_outside_compilation_shape_inference_cluster_0");
+ EXPECT_EQ(shape_inference_graph, "");
TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph",
&shape_inference_graph));
EXPECT_EQ(shape_inference_graph, "");
// Check `shape_inference_graphs`.
- EXPECT_EQ(shape_inference_graphs.size(), 1);
- EXPECT_EQ(shape_inference_graphs[0],
- "_outside_compilation_shape_inference_cluster_0");
+ EXPECT_EQ(shape_inference_graphs.size(), 0);
// Check `host_graph`: verify we have key placeholder and sequencer.
Node *key_placeholder = nullptr, *sequencer = nullptr;
@@ -333,8 +330,8 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) {
send_recv_nodes.push_back(n);
}
}
- EXPECT_EQ(num_send_from_host, 2);
- EXPECT_EQ(num_recv_at_host, 2);
+ EXPECT_EQ(num_send_from_host, 1);
+ EXPECT_EQ(num_recv_at_host, 1);
for (Node *n : send_recv_nodes) {
Node *input_node;
TF_CHECK_OK(n->input_node(n->num_inputs() - 1, &input_node));
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
new file mode 100644
index 0000000000000000000000000000000000000000..98e344b3a080aa8aab27cd41564a90427bac151e
--- /dev/null
+++ b/tensorflow/compiler/jit/flags.cc
@@ -0,0 +1,152 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <mutex>  // NOLINT
+
+#include "tensorflow/compiler/jit/flags.h"
+#include "tensorflow/compiler/xla/parse_flags_from_env.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace {
+
+BuildXlaOpsPassFlags* build_ops_flags;
+DumpGraphFlags* dump_graph_flags;
+MarkForCompilationPassFlags* mark_for_compilation_flags;
+XlaDeviceFlags* device_flags;
+XlaOpsCommonFlags* ops_flags;
+// Definitions parsed from TF_XLA_FLAGS; flags_init makes the parse one-shot.
+std::vector<Flag>* flag_list;
+std::once_flag flags_init;
+
+// Appends the DumpGraphFlags flag definition, which writes into the
+// file-level dump_graph_flags struct, to `flag_list`.
+void AppendDumpGraphFlagsInternal(std::vector<Flag>* flag_list) {
+  flag_list->push_back(
+      Flag("tf_dump_graph_prefix", &dump_graph_flags->tf_dump_graph_prefix,
+           "Path prefix to which graphs dumped during debugging should be "
+           "written."));
+}
+
+// Appends the MarkForCompilationPassFlags flag definitions to `flag_list`.
+void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
+  flag_list->insert(flag_list->end(), {
+      Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit,
+           "Control compilation of operators into XLA computations on CPU and "
+           "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
+           "things very likely to be improved; 2 = on for everything. "
+           "Experimental."),
+      Flag("tf_xla_min_cluster_size",
+           &mark_for_compilation_flags->tf_xla_min_cluster_size,
+           "Minimum number of operators in an XLA compilation. Ignored for "
+           "operators placed on an XLA device or operators explicitly marked "
+           "for compilation."),
+      Flag("tf_xla_max_cluster_size",
+           &mark_for_compilation_flags->tf_xla_max_cluster_size,
+           "Maximum number of operators in an XLA compilation."),
+      Flag("tf_xla_clustering_debug",
+           &mark_for_compilation_flags->tf_xla_clustering_debug,
+           "Dump graphs during XLA compilation."),
+      Flag("tf_xla_cpu_global_jit",
+           &mark_for_compilation_flags->tf_xla_cpu_global_jit,
+           "Enables global JIT compilation for CPU via SessionOptions."),
+      Flag("tf_xla_clustering_fuel",
+           &mark_for_compilation_flags->tf_xla_clustering_fuel,
+           "Places an artificial limit on the number of ops marked as "
+           "eligible for clustering."),
+      Flag("tf_xla_fusion_only",
+           &mark_for_compilation_flags->tf_xla_fusion_only,
+           "enable fusion of element-wise operations only using XLA when "
+           "global_jit_level is ON*.")});
+}
+
+// Allocates every flags struct with its default values, registers the flag
+// definitions, and parses the TF_XLA_FLAGS environment variable into them.
+void AllocateAndParseFlags() {
+  build_ops_flags = new BuildXlaOpsPassFlags;
+  build_ops_flags->tf_xla_enable_lazy_compilation = true;
+
+  dump_graph_flags = new DumpGraphFlags;
+  dump_graph_flags->tf_dump_graph_prefix = "/tmp/";
+
+  mark_for_compilation_flags = new MarkForCompilationPassFlags;
+  mark_for_compilation_flags->tf_xla_auto_jit = 0;
+  mark_for_compilation_flags->tf_xla_min_cluster_size = 2;
+  mark_for_compilation_flags->tf_xla_max_cluster_size =
+      std::numeric_limits<int32>::max();
+  mark_for_compilation_flags->tf_xla_clustering_debug = false;
+  mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
+  mark_for_compilation_flags->tf_xla_clustering_fuel =
+      std::numeric_limits<int64>::max();
+  mark_for_compilation_flags->tf_xla_fusion_only = false;
+
+  device_flags = new XlaDeviceFlags;
+  device_flags->tf_xla_compile_on_demand = false;
+
+  ops_flags = new XlaOpsCommonFlags;
+  ops_flags->tf_xla_always_defer_compilation = false;
+
+  flag_list = new std::vector<Flag>({
+      Flag("tf_xla_enable_lazy_compilation",
+           &build_ops_flags->tf_xla_enable_lazy_compilation, ""),
+      Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
+           "Switch a device into 'on-demand' mode, where instead of "
+           "autoclustering ops are compiled one by one just-in-time."),
+      Flag("tf_xla_always_defer_compilation",
+           &ops_flags->tf_xla_always_defer_compilation, ""),
+  });
+  AppendDumpGraphFlagsInternal(flag_list);
+  AppendMarkForCompilationPassFlagsInternal(flag_list);
+  xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
+}
+
+} // namespace
+
+const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() {
+ std::call_once(flags_init, &AllocateAndParseFlags);
+ return *build_ops_flags;
+}
+
+DumpGraphFlags* GetDumpGraphFlags() {
+ std::call_once(flags_init, &AllocateAndParseFlags);
+ return dump_graph_flags;
+}
+
+MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
+ std::call_once(flags_init, &AllocateAndParseFlags);
+ return mark_for_compilation_flags;
+}
+
+XlaDeviceFlags* GetXlaDeviceFlags() {
+ std::call_once(flags_init, &AllocateAndParseFlags);
+ return device_flags;
+}
+
+const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
+ std::call_once(flags_init, &AllocateAndParseFlags);
+ return *ops_flags;
+}
+
+void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
+  std::call_once(flags_init, &AllocateAndParseFlags);  // Structs must exist.
+  AppendMarkForCompilationPassFlagsInternal(flag_list);
+}
+
+void AppendDumpGraphFlags(std::vector<Flag>* flag_list) {
+  std::call_once(flags_init, &AllocateAndParseFlags);  // Structs must exist.
+  AppendDumpGraphFlagsInternal(flag_list);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/flags.h
similarity index 57%
rename from tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h
rename to tensorflow/compiler/jit/flags.h
index 79b47357a179d2d9e0d1b6bf9c9f814288bcd5e1..5ddea588eef5270880d91623dc05893da265960a 100644
--- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -13,10 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
-#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
-
-// Legacy flags for the XLA bridge's mark_for_compilation_pass module.
+#ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_
+#define TENSORFLOW_COMPILER_JIT_FLAGS_H_
#include
@@ -24,15 +22,8 @@ limitations under the License.
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
-namespace legacy_flags {
-
-// Append to *flag_list flag definitions associated with the XLA bridge's
-// mark_for_compilation_pass module.
-void AppendMarkForCompilationPassFlags(
- std::vector* flag_list);
-// The values of flags associated with the XLA bridge's
-// mark_for_compilation_pass module.
+// Flags associated with the XLA bridge's mark_for_compilation_pass module.
struct MarkForCompilationPassFlags {
int32 tf_xla_auto_jit; // Control compilation of operators into XLA
// computations on CPU and GPU devices. 0 = use
@@ -57,12 +48,56 @@ struct MarkForCompilationPassFlags {
// only using XLA.
};
-// Return a pointer to the MarkForCompilationPassFlags struct;
+// Flags associated with the XLA bridge's xla_device module.
+struct XlaDeviceFlags {
+ // Switch the CPU device into "on-demand" mode, where instead of
+ // autoclustering ops are compiled one by one just-in-time.
+ // Enabling this mode by a legacy flag is a temporary mechanism. When this
+ // feature is battle-tested, we will switch this to be a session option.
+ bool tf_xla_compile_on_demand;
+};
+
+// Flags common to the _Xla* ops and their kernels.
+struct XlaOpsCommonFlags {
+ // If true, _XlaCompile always refuses to compile the cluster, which means the
+ // XLA clusters always run in the TF executor. Defaults to false.
+ bool tf_xla_always_defer_compilation;
+};
+
+// Flags for the build_xla_ops pass.
+struct BuildXlaOpsPassFlags {
+ // Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
+ // Defaults to true.
+ bool tf_xla_enable_lazy_compilation;
+};
+
+// Flags for the XLA bridge's dump_graph module.
+struct DumpGraphFlags {
+ // Path prefix to which graphs dumped during debugging should be written.
+ string tf_dump_graph_prefix;
+};
+
+// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
+
+// Getters for flags structs defined above. The first call to any of these
+// parses TF_XLA_FLAGS for all of them. Those functions which return a pointer
+// always return the same pointer.
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
+const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags();
+XlaDeviceFlags* GetXlaDeviceFlags();
+const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
+DumpGraphFlags* GetDumpGraphFlags();
+
+// Appends the flag definitions associated with
+// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
+//
+// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
+void AppendMarkForCompilationPassFlags(
+    std::vector<tensorflow::Flag>* flag_list);
+void AppendDumpGraphFlags(std::vector<tensorflow::Flag>* flag_list);
-} // namespace legacy_flags
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
+#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_
diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc
index d984ca15cb722821b2a466a90387a29cbc1d1097..ce53f70b79d97ab087fefe542920b33f883632a2 100644
--- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc
+++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
-#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
@@ -208,8 +208,12 @@ Status ComputeSliceSize(const Scope& host_scope,
DCHECK_EQ(slice_size.back().type(), DT_INT64);
}
- *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size,
- ops::Const(host_scope.WithOpName("concat_axis"), 0));
+ // Trivial ConcatV2 nodes (with exactly one input) are disallowed.
+ *size =
+ slice_size.size() == 1
+ ? slice_size[0]
+ : ops::Concat(host_scope.WithOpName("slice_size"), slice_size,
+ ops::Const(host_scope.WithOpName("concat_axis"), 0));
return Status::OK();
}
@@ -242,6 +246,9 @@ Status ConvertTensorFlowSliceToStaticShapedSlice(
.WithOpName("static_shaped_slice"),
slice_inputs_int64.input, slice_inputs_int64.begin, slice_size)
.node();
+
+ TF_RETURN_IF_ERROR(main_scope.status());
+
std::vector compile_time_const_inputs;
compile_time_const_inputs.push_back("size");
(*result)->AddAttr(kXlaCompileTimeConstantInputsAttr,
@@ -284,49 +291,45 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
return Status::OK();
}
-// If `n` is a slice we can rewrite to have a static shape (i.e. have the output
-// shape only depend on the "size" input) then returns the a SliceInputs
-// representing the inputs to `n`. Otherwise returns nullopt.
-StatusOrOptional IsRewritableSlice(Node* n) {
+// Return true if `n` is a slice we can rewrite to have a static shape
+// (i.e. have the output shape only depend on the "size" input).
+xla::StatusOr IsRewritableSlice(Node* n) {
if (n->type_string() != "Slice") {
- return {absl::nullopt};
+ return false;
}
if (!GetXlaClusterForNode(*n).has_value()) {
// There is no need to change slice ops outside XLA clusters.
- return {absl::nullopt};
+ return false;
}
TF_ASSIGN_OR_RETURN(absl::optional slice_inputs,
GetSliceInputs(n));
if (!slice_inputs.has_value()) {
- return {absl::nullopt};
+ return false;
}
// If slice_size[i] < -1 for any i then executing the slice will throw an
// error, and we don't do anything here.
- bool slice_is_ok = absl::c_all_of(slice_inputs->size_as_vector,
- [](int64 size_i) { return size_i >= -1; });
- if (!slice_is_ok) {
- return {absl::nullopt};
- }
-
- return slice_inputs;
+ return absl::c_all_of(slice_inputs->size_as_vector,
+ [](int64 size_i) { return size_i >= -1; });
}
Status FindAndRewriteSlices(Graph* g, bool* changed) {
- std::vector> slices_to_rewrite;
+ std::vector slices_to_rewrite;
for (Node* n : g->nodes()) {
- TF_ASSIGN_OR_RETURN(absl::optional slice_inputs,
- IsRewritableSlice(n));
- if (slice_inputs.has_value()) {
- slices_to_rewrite.push_back({n, std::move(*slice_inputs)});
+ TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n));
+ if (is_rewritable) {
+ slices_to_rewrite.push_back(n);
}
}
- for (const auto& pair : slices_to_rewrite) {
- TF_RETURN_IF_ERROR(RewriteSlice(g, pair.first, pair.second,
- *GetXlaClusterForNode(*pair.first)));
+ for (Node* n : slices_to_rewrite) {
+ TF_ASSIGN_OR_RETURN(absl::optional slice_inputs,
+ GetSliceInputs(n));
+ TF_RET_CHECK(slice_inputs.has_value());
+ TF_RETURN_IF_ERROR(
+ RewriteSlice(g, n, *slice_inputs, *GetXlaClusterForNode(*n)));
}
if (!slices_to_rewrite.empty()) {
@@ -342,8 +345,7 @@ Status FindAndRewriteSlices(Graph* g, bool* changed) {
Status IncreaseDynamismForAutoJitPass::Run(
const GraphOptimizationPassOptions& options) {
- legacy_flags::MarkForCompilationPassFlags* flags =
- legacy_flags::GetMarkForCompilationPassFlags();
+ MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
if (flags->tf_xla_clustering_debug) {
dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass",
**options.graph, options.flib_def);
diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc
index 0f6f612e967035f6af3e4aff2a499d5cedd018af..a2f1b831ad7605237e23c15cc43b337e06265553 100644
--- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc
+++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
namespace tensorflow {
namespace {
+using ::testing::_;
using testing::matchers::AssignedDevice;
using testing::matchers::Attr;
using testing::matchers::Const;
@@ -142,6 +143,26 @@ TEST(SliceToDynamicSliceRewriteTest, Basic) {
EXPECT_THAT(static_shaped_slice, m_dynamic_slice);
}
+TEST(SliceToDynamicSliceRewriteTest, SliceFromVector) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
+  Output size = ops::Const(root.WithOpName("size"), {-1});  // Rank-1 size.
+  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  Node* static_shaped_slice = testing::FindNodeByName(
+      result.get(), "slice/static_shaped_slice/static_shaped_slice");
+  EXPECT_NE(static_shaped_slice, nullptr);
+  EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("ConcatV2")))));
+}
+
TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) {
Scope root = Scope::NewRootScope()
.ExitOnError()
@@ -166,18 +187,18 @@ TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) {
CtrlDeps(NodeWith(Op("Placeholder"), Name("control")))));
}
+int64 ToInt64(int v) { return static_cast<int64>(v); }  // For int64 Consts.
+
TEST(SliceToDynamicSliceRewriteTest, Int64Indices) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
- auto to_int64 = [](int v) { return static_cast(v); };
-
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
Output size =
- ops::Const(root.WithOpName("size"), {to_int64(-1), to_int64(500)});
+ ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(500)});
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
std::unique_ptr result;
@@ -252,13 +273,35 @@ TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) {
Attr(kXlaCompileTimeConstantInputsAttr)))));
}
+TEST(SliceToDynamicSliceRewriteTest, ScalarSlice) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
+  Output size = ops::Const<int64>(root.WithOpName("size"), {});  // Empty size.
+  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  Node* static_shaped_slice = testing::FindNodeByName(
+      result.get(), "slice/static_shaped_slice/static_shaped_slice");
+  ASSERT_NE(static_shaped_slice, nullptr);
+  EXPECT_THAT(static_shaped_slice,
+              NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr),
+                       Inputs(_, _, Out(NodeWith(Name(size.node()->name()))))));
+}
+
TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
- auto to_int64 = [](int v) { return static_cast(v); };
+ auto ToInt64 = [](int v) { return static_cast(v); };
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
@@ -271,7 +314,7 @@ TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);
Output size =
- ops::Const(root.WithOpName("size"), {{to_int64(-1)}, {to_int64(500)}});
+ ops::Const(root.WithOpName("size"), {{ToInt64(-1)}, {ToInt64(500)}});
TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2));
std::unique_ptr result;
@@ -281,5 +324,82 @@ TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
Not(Contains(NodeWith(Op("Slice"),
Attr(kXlaCompileTimeConstantInputsAttr)))));
}
+
+TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceInput) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
+  Output size_a = ops::Const(root.WithOpName("size_a"), {-1, 500});
+  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size_a);
+  // The second slice consumes the output of the first slice as its input.
+  Output size_b = ops::Const(root.WithOpName("size_b"), {-1, 200});
+  Output slice_with_slice_input = ops::Slice(
+      root.WithOpName("slice_with_slice_input"), slice, begin, size_b);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  Node* static_shaped_slice = testing::FindNodeByName(
+      result.get(),
+      "slice_with_slice_input/static_shaped_slice/static_shaped_slice");
+  ASSERT_NE(static_shaped_slice, nullptr);
+  EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT)
+      << "Expected DT_FLOAT, was "
+      << DataType_Name(static_shaped_slice->output_type(0));
+  EXPECT_THAT(
+      static_shaped_slice,
+      NodeWith(
+          Op("Slice"),
+          Inputs(Out(NodeWith(
+                     Op("Slice"),
+                     Name("slice/static_shaped_slice/static_shaped_slice"))),
+                 _, _)));
+}
+
+TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input_float =
+      ops::Placeholder(root.WithOpName("input_float"), DT_FLOAT);
+  Output input_i64 = ops::Placeholder(root.WithOpName("input_i64"), DT_INT64);
+  // `begin` is itself the output of a slice with a constant size.
+  Output begin_begin =
+      ops::Placeholder(root.WithOpName("begin_begin"), DT_INT32);
+  Output begin_size = ops::Const(root.WithOpName("begin_size"), {-1});
+  Output begin =
+      ops::Slice(root.WithOpName("begin"), input_i64, begin_begin, begin_size);
+
+  Output size =
+      ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(200)});
+  Output slice_with_slice_begin = ops::Slice(
+      root.WithOpName("slice_with_slice_begin"), input_float, begin, size);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  Node* static_shaped_slice = testing::FindNodeByName(
+      result.get(),
+      "slice_with_slice_begin/static_shaped_slice/static_shaped_slice");
+  ASSERT_NE(static_shaped_slice, nullptr);
+  EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT)
+      << "Expected DT_FLOAT, was "
+      << DataType_Name(static_shaped_slice->output_type(0));
+  EXPECT_THAT(
+      static_shaped_slice,
+      NodeWith(
+          Op("Slice"),
+          Inputs(_,
+                 Out(NodeWith(
+                     Op("Slice"),
+                     Name("begin/static_shaped_slice/static_shaped_slice"))),
+                 _)));
+}
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 830db9ebdd92608c375ad778eced833e26729325..0583774714c6db7a2fa515fc8a0d304e1898db97 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -12,10 +12,10 @@ cc_library(
hdrs = ["xla_ops.h"],
deps = [
"//tensorflow/compiler/jit:common",
+ "//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:xla_compilation_cache",
"//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/jit:xla_launch_util",
- "//tensorflow/compiler/jit/legacy_flags:xla_ops_common_flags",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index 055de7afcc538a1a1183f3687d998a5b2211c887..ad71df5a694a5f8da94675049df1062a7edb6253 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h"
+#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -418,7 +418,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
cannot_compile_cluster = cannot_compile_cluster_;
}
- if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
+ if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
cannot_compile_cluster) {
executable = nullptr;
} else {
diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD
deleted file mode 100644
index 5fa6c85f06f863f5d18bc4939ffa0ae820d222bd..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/jit/legacy_flags/BUILD
+++ /dev/null
@@ -1,65 +0,0 @@
-# Legacy command line flags for the XLA bridge libraries.
-
-# Please do not add more flags to this package.
-
-# The XLA bridge libraries were written in an environment that allowed
-# command-line flags to be scattered freely throughout the libraries. This
-# model, while initially convenient, leads to a proliferation in unused command
-# line flags in tests and binaries, and serious problems in servers, where one
-# might wish parameters to be different in independent RPC calls to the same
-# routine.
-#
-# Please don't add more flags. If you're a library author, pass options and
-# parameters explicitly through the library's interface.
-
-licenses(["notice"]) # Apache 2.0
-
-package(default_visibility = ["//tensorflow:internal"])
-
-cc_library(
- name = "mark_for_compilation_pass_flags",
- srcs = ["mark_for_compilation_pass_flags.cc"],
- hdrs = ["mark_for_compilation_pass_flags.h"],
- deps =
- [
- "//tensorflow/compiler/xla:parse_flags_from_env",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
- name = "xla_device_flags",
- srcs = ["xla_device_flags.cc"],
- hdrs = ["xla_device_flags.h"],
- deps =
- [
- "//tensorflow/compiler/xla:parse_flags_from_env",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
- name = "build_xla_ops_pass_flags",
- srcs = ["build_xla_ops_pass_flags.cc"],
- hdrs = ["build_xla_ops_pass_flags.h"],
- deps =
- [
- "//tensorflow/compiler/xla:parse_flags_from_env",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
- name = "xla_ops_common_flags",
- srcs = ["xla_ops_common_flags.cc"],
- hdrs = ["xla_ops_common_flags.h"],
- deps =
- [
- "//tensorflow/compiler/xla:parse_flags_from_env",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
-)
diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc
deleted file mode 100644
index 961c17c17eac891261530ef25baaa50f8496c331..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc
+++ /dev/null
@@ -1,47 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include // NOLINT
-
-#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h"
-#include "tensorflow/compiler/xla/parse_flags_from_env.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-namespace {
-
-BuildXlaOpsPassFlags* flags;
-std::vector* flag_list;
-std::once_flag flags_init;
-
-void AllocateAndParseFlags() {
- flags = new BuildXlaOpsPassFlags;
- flags->tf_xla_enable_lazy_compilation = true;
- flag_list = new std::vector({
- Flag("tf_xla_enable_lazy_compilation",
- &flags->tf_xla_enable_lazy_compilation, ""),
- });
- xla::ParseFlagsFromEnv(*flag_list);
-}
-
-} // namespace
-
-const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() {
- std::call_once(flags_init, &AllocateAndParseFlags);
- return *flags;
-}
-} // namespace legacy_flags
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h
deleted file mode 100644
index 9aa5cf64d6db56ae36875ca08d2ae88c73604733..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h
+++ /dev/null
@@ -1,37 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_
-#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Flags for the build_xla_ops pass.
-struct BuildXlaOpsPassFlags {
- // Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
- // Defaults to true.
- bool tf_xla_enable_lazy_compilation;
-};
-
-// Parses the flags in BuildXlaOpsPassFlags from the TF_XLA_FLAGS environment
-// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS
-// only the first time this routine is called.
-const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags();
-
-} // namespace legacy_flags
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_
diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc
deleted file mode 100644
index bad306e0b0a3061ba13dc69c08066c642667a2b9..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc
+++ /dev/null
@@ -1,98 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Legacy flags for the XLA bridge's mark_for_compilation_pass module.
-
-#include
-#include
-
-#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
-#include "tensorflow/compiler/xla/parse_flags_from_env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Pointers to the parsed value of the flags and flag descriptors, initialized
-// via flags_init.
-static MarkForCompilationPassFlags* flags;
-static std::vector<Flag>* flag_list;
-static std::once_flag flags_init;
-
-// Allocate *flags. Called via call_once(&flags_init,...).
-static void AllocateFlags() {
- flags = new MarkForCompilationPassFlags;
- flags->tf_xla_auto_jit = 0;
- flags->tf_xla_min_cluster_size = 2;
- flags->tf_xla_max_cluster_size = std::numeric_limits<int32>::max();
- flags->tf_xla_clustering_debug = false;
- flags->tf_xla_cpu_global_jit = false;
- flags->tf_xla_clustering_fuel = std::numeric_limits<int64>::max();
- flags->tf_xla_fusion_only = false;
- flag_list = new std::vector<Flag>(
- {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
- "Control compilation of operators into XLA computations on CPU and "
- "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
- "things very likely to be improved; 2 = on for everything. "
- "Experimental."),
- Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size,
- "Minimum number of operators in an XLA compilation. Ignored for "
- "operators placed on an XLA device or operators explicitly marked "
- "for compilation."),
- Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size,
- "Maximum number of operators in an XLA compilation."),
- Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
- "Dump graphs during XLA compilation."),
- Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit,
- "Enables global JIT compilation for CPU via SessionOptions."),
- Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel,
- "Places an artificial limit on the number of ops marked as "
- "eligible for clustering."),
- Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only,
- "enable fusion of element-wise operations only using XLA when "
- "global_jit_level is ON*.")});
- xla::ParseFlagsFromEnv(*flag_list);
-
- if (VLOG_IS_ON(1)) {
- VLOG(1) << "Parsed MarkForCompilationPassFlags:";
- VLOG(1) << " tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
- VLOG(1) << " tf_xla_min_cluster_size = " << flags->tf_xla_min_cluster_size;
- VLOG(1) << " tf_xla_max_cluster_size = " << flags->tf_xla_max_cluster_size;
- VLOG(1) << " tf_xla_clustering_debug = " << flags->tf_xla_clustering_debug;
- VLOG(1) << " tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
- VLOG(1) << " tf_xla_clustering_fuel = " << flags->tf_xla_clustering_fuel;
- VLOG(1) << " tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
- }
-}
-
-// Append to *append_to flag definitions associated with the XLA bridge's
-// mark_for_compilation_pass module.
-void AppendMarkForCompilationPassFlags(std::vector<Flag>* append_to) {
- std::call_once(flags_init, &AllocateFlags);
- append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
-}
-
-// Return a pointer to the MarkForCompilationPassFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
- std::call_once(flags_init, &AllocateFlags);
- return flags;
-}
-
-} // namespace legacy_flags
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc
deleted file mode 100644
index 76b80d3034c8a13a1ddf1afe548d5c3d9c7b2cec..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc
+++ /dev/null
@@ -1,56 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Legacy flags for the XLA bridge's xla_device module.
-
-#include
-#include
-
-#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
-#include "tensorflow/compiler/xla/parse_flags_from_env.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Pointers to the parsed value of the flags and flag descriptors, initialized
-// via flags_init.
-static XlaDeviceFlags* flags;
-static std::vector<Flag>* flag_list;
-static std::once_flag flags_init;
-
-// Allocate *flags. Called via call_once(&flags_init,...).
-static void AllocateFlags() {
- flags = new XlaDeviceFlags;
- flags->tf_xla_compile_on_demand = false;
- flag_list = new std::vector<Flag>({
- Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand,
- "Switch a device into 'on-demand' mode, where instead of "
- "autoclustering ops are compiled one by one just-in-time."),
- });
- xla::ParseFlagsFromEnv(*flag_list);
-}
-
-// Return a pointer to the XlaDeviceFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-XlaDeviceFlags* GetXlaDeviceFlags() {
- std::call_once(flags_init, &AllocateFlags);
- return flags;
-}
-
-} // namespace legacy_flags
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h
deleted file mode 100644
index 27b22121ac1e089bd5d5a494e1e3fb60b05bc76d..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_
-#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_
-
-// Legacy flags for the XLA bridge's xla_device module.
-
-#include
-
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// The values of flags associated with the XLA bridge's
-// xla_device module.
-typedef struct {
- // Switch the CPU device into "on-demand" mode, where instead of
- // autoclustering ops are compiled one by one just-in-time.
- // Enabling this mode by a legacy flag is a temporary mechanism. When this
- // feature is battle-tested, we will switch this to be a session option.
- bool tf_xla_compile_on_demand;
-} XlaDeviceFlags;
-
-// Return a pointer to the XlaDeviceFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-XlaDeviceFlags* GetXlaDeviceFlags();
-
-} // namespace legacy_flags
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_
diff --git a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc
deleted file mode 100644
index 1443d48a734c0a44c1cd91d8d1218bdbed7f765c..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <mutex> // NOLINT
-#include
-
-#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h"
-#include "tensorflow/compiler/xla/parse_flags_from_env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-XlaOpsCommonFlags* flags;
-std::vector<Flag>* flag_list;
-std::once_flag flags_init;
-
-void AllocateAndParseFlags() {
- flags = new XlaOpsCommonFlags;
- flags->tf_xla_always_defer_compilation = false;
- flag_list = new std::vector<Flag>({
- Flag("tf_xla_always_defer_compilation",
- &flags->tf_xla_always_defer_compilation, ""),
- });
- xla::ParseFlagsFromEnv(*flag_list);
-
- if (VLOG_IS_ON(1)) {
- VLOG(1) << "Parsed XlaOpsCommonFlags:";
- VLOG(1) << " tf_xla_always_defer_compilation = "
- << flags->tf_xla_always_defer_compilation;
- }
-}
-
-const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
- std::call_once(flags_init, &AllocateAndParseFlags);
- return *flags;
-}
-} // namespace legacy_flags
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h
deleted file mode 100644
index 7c5c1818ef2d1dcf38c324a2c926db9c4bfa8ef5..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_
-#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Flags common to the _Xla* ops and their kernels.
-struct XlaOpsCommonFlags {
- // If true, _XlaCompile always refuses to compile the cluster, which means the
- // XLA clusters always run in the TF executor. Defaults to false.
- bool tf_xla_always_defer_compilation;
-};
-
-// Parses the flags in XlaOpsCommonFlags from the TF_XLA_FLAGS environment
-// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS
-// only the first time this routine is called.
-const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
-
-} // namespace legacy_flags
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 70033cae0afacb6a25598ee1abf2aeb2721e7496..6618e3a58ab7b6374ed775cd6e4e18a6a4975588 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -24,8 +24,8 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
-#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
@@ -72,6 +72,11 @@ struct OperationFilter {
// to resort to a dummy implementation. Currently Assert and CheckNumerics ops
// have dummy XLA implementations.
bool allow_dummy_ops;
+
+ // Whether ops that produce or consume DT_VARIANT values are allowed. We
+ // don't auto-cluster these ops because we don't yet support live-in or
+ // live-out DT_VARIANT values.
+ bool allow_ops_producing_or_consuming_variant;
};
bool IsDummyImplOp(absl::string_view op_name) {
@@ -81,7 +86,13 @@ bool IsDummyImplOp(absl::string_view op_name) {
bool IsStatefulRandomOp(absl::string_view op_name) {
return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
- op_name == "TruncatedNormal";
+ op_name == "TruncatedNormal" || op_name == "Multinomial";
+}
+
+bool OpProducesOrConsumesVariant(const Node& node) {
+ auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
+ return absl::c_any_of(node.input_types(), is_variant) ||
+ absl::c_any_of(node.output_types(), is_variant);
}
bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
@@ -246,6 +257,10 @@ bool IsCompilableCall(const NodeDef& call_def,
if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
return false;
}
+ if (!op_filter.allow_ops_producing_or_consuming_variant &&
+ OpProducesOrConsumesVariant(*node)) {
+ return false;
+ }
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1,
lib_runtime)) {
@@ -427,8 +442,7 @@ Status FindCompilationCandidates(
BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr,
&compile_time_const_nodes));
- int64& fuel =
- legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
+ int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
// Iterate over nodes in sorted order so that compiler fuel is deterministic.
// We can't simply pass op_nodes().begin() and op_nodes().end to the
@@ -471,16 +485,15 @@ Status FindCompilationCandidates(
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
DeviceType jit_device_type(registration->compilation_device_name);
+ bool always_auto_cluster = registration->autoclustering_policy ==
+ XlaOpRegistry::AutoclusteringPolicy::kAlways;
+
OperationFilter op_filter;
op_filter.allow_resource_ops = registration->compile_resource_ops;
- op_filter.allow_stateful_rng_ops =
- (registration->autoclustering_policy ==
- XlaOpRegistry::AutoclusteringPolicy::kAlways);
- op_filter.allow_control_trigger =
- (registration->autoclustering_policy ==
- XlaOpRegistry::AutoclusteringPolicy::kAlways);
- op_filter.allow_dummy_ops = (registration->autoclustering_policy ==
- XlaOpRegistry::AutoclusteringPolicy::kAlways);
+ op_filter.allow_stateful_rng_ops = always_auto_cluster;
+ op_filter.allow_control_trigger = always_auto_cluster;
+ op_filter.allow_dummy_ops = always_auto_cluster;
+ op_filter.allow_ops_producing_or_consuming_variant = always_auto_cluster;
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, op_filter, 0,
@@ -504,6 +517,12 @@ Status FindCompilationCandidates(
<< node->type_string() << ")";
continue;
}
+ if (!op_filter.allow_ops_producing_or_consuming_variant &&
+ OpProducesOrConsumesVariant(*node)) {
+ VLOG(2) << "Rejecting " << node->name()
+ << ": produces or consumes DT_VARIANT";
+ continue;
+ }
if (!op_filter.allow_resource_ops &&
(HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
@@ -607,8 +626,7 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
// To set compilation to be on by default, change the following line.
global_jit_level = OptimizerOptions::OFF;
}
- legacy_flags::MarkForCompilationPassFlags* flags =
- legacy_flags::GetMarkForCompilationPassFlags();
+ MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
if (flags->tf_xla_auto_jit == -1 ||
(1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
// If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
@@ -641,6 +659,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
op_filter.allow_stateful_rng_ops = true;
op_filter.allow_control_trigger = true;
op_filter.allow_dummy_ops = true;
+ op_filter.allow_ops_producing_or_consuming_variant = true;
return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr);
}
@@ -651,8 +670,7 @@ Status MarkForCompilationPass::Run(
// device ahead of time.
OptimizerOptions::GlobalJitLevel global_jit_level =
GetGlobalJitLevel(options);
- legacy_flags::MarkForCompilationPassFlags* flags =
- legacy_flags::GetMarkForCompilationPassFlags();
+ MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
bool fusion_only = flags->tf_xla_fusion_only;
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
@@ -953,8 +971,7 @@ Status MarkForCompilationPass::RunImpl(
OptimizerOptions::GlobalJitLevel global_jit_level =
GetGlobalJitLevel(options);
- legacy_flags::MarkForCompilationPassFlags* flags =
- legacy_flags::GetMarkForCompilationPassFlags();
+ MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle.
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 24d78c077268f83cebbdafddc1a658ae8dc6b8d8..bf2c5508ea9e987e80093f4c2e15d3ff5191126f 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
@@ -1147,5 +1148,80 @@ TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
EXPECT_EQ(clusters["test/check"], "");
}
+TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
+ Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);
+
+ Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
+ Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);
+
+ Output tensor_list_reserve = ops::TensorListReserve(
+ root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
+}
+
+TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output dummy_input =
+ ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64);
+ Output variant_input =
+ ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT);
+
+ // Create one more node so that we don't avoid creating a cluster solely
+ // because it would be trivial.
+ Output dummy_cast =
+ ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32);
+
+ Output tensor_list_element_shape = ops::TensorListElementShape(
+ root.WithOpName("test/tensor_list_element_shape"), variant_input,
+ DT_INT32);
+
+ root.graph()->AddControlEdge(dummy_cast.node(),
+ tensor_list_element_shape.node());
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_EQ(clusters["test/tensor_list_element_shape"], "");
+}
+
+TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
+ Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);
+
+ Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
+ Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);
+
+ Output tensor_list_reserve = ops::TensorListReserve(
+ root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(xla_cpu_device);
+ }
+ }
+
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/tensor_list_reserve"], "");
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
index d56d0f8ccfcdab40003be38059228cb255921b64..64a3301745790132fe3149bf8fb52d6c45ecc3c1 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -34,15 +34,9 @@ namespace tensorflow {
//
// It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
// make this more direct, but probably not worth it solely for this test.
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
- auto delete_devices = gtl::MakeCleanup([&] {
- for (Device* d : devices) {
- delete d;
- }
- });
-
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.session_options = session_options;
diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD
index f72224545b25bc7100e0b6788e6fbf0a7ca63dad..64409d9334751e0edfce9091a4e5697dd2c712c5 100644
--- a/tensorflow/compiler/jit/ops/BUILD
+++ b/tensorflow/compiler/jit/ops/BUILD
@@ -18,3 +18,9 @@ tf_gen_op_wrapper_py(
out = "xla_ops.py",
deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
)
+
+py_library(
+ name = "xla_ops_grad",
+ srcs = ["xla_ops_grad.py"],
+ deps = ["//tensorflow/python:framework_ops"],
+)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/compiler/jit/ops/xla_ops_grad.py
similarity index 62%
rename from tensorflow/contrib/estimator/python/estimator/dnn.py
rename to tensorflow/compiler/jit/ops/xla_ops_grad.py
index 10f657df8de64cc96f0cf04f434a77df66629dca..2d31d8dc714307a48932d061fb1af643940a0872 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn.py
+++ b/tensorflow/compiler/jit/ops/xla_ops_grad.py
@@ -1,3 +1,4 @@
+"""Gradients for XLA ops."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,21 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""dnn python module.
-
-Importing from tensorflow.python.estimator is unsupported
-and will soon break!
-"""
-# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow_estimator.contrib.estimator.python.estimator import dnn
+from tensorflow.python.framework import ops
-# Include attrs that start with single underscore.
-_HAS_DYNAMIC_ATTRIBUTES = True
-dnn.__all__ = [s for s in dir(dnn) if not s.startswith('__')]
-from tensorflow_estimator.contrib.estimator.python.estimator.dnn import *
+@ops.RegisterGradient("XlaClusterOutput")
+def _XlaClusterOutputGrad(_, grad):
+ del grad # unused
+ raise RuntimeError("Gradient computation of graph in xla.compile() is "
+ "prohibited because it can cause performance degradation. "
+ "Please move gradient computation inside xla.compile().")
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 36b345ecbff8d5f6ba3c241b9e164f677236c20d..42ea3926e16ae791dbe1bede3b8742383db7667c 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -26,6 +26,10 @@ limitations under the License.
namespace tensorflow {
namespace {
+
+bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
+
+namespace reduce_device_to_host_copies {
Status FindNodesToDecluster(const Graph& graph,
absl::flat_hash_set<Node*>* result,
absl::Span<Node* const> post_order) {
@@ -140,8 +144,6 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
return Status::OK();
}
-bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
-
// Clones nodes to outside their cluster to avoid device-to-host copies. For
// instance, converts this:
//
@@ -168,7 +170,7 @@ bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
// where the ===> arrow has a hostmem source and destination and would entail a
// device to host copy if the source and destination were not in the same XLA
// cluster.
-Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
+Status PartiallyDeclusterGraph(Graph* graph) {
// When deciding whether to decluster a particular node, we base our decision
// on if we've decided that some of its consumers have to be declustered too.
// Iterating the graph in post-order guarantees that consumers have been
@@ -206,7 +208,9 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
return Status::OK();
}
+} // namespace reduce_device_to_host_copies
+namespace reduce_recompilation {
bool IsIntraClusterEdge(const Edge& edge) {
absl::optional<absl::string_view> src_cluster_name =
GetXlaClusterForNode(*edge.src());
@@ -269,7 +273,7 @@ Status MustCompileNode(const Node* n, bool* must_compile) {
// regress performance in any significant manner. We will have to revisit this
// algorith with a more complex cost model if this assumption turns out to be
// incorrect.
-Status DeclusterNodesToReduceRecompilations(Graph* graph) {
+Status PartiallyDeclusterGraph(Graph* graph) {
std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
*graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
@@ -322,7 +326,7 @@ Status DeclusterNodesToReduceRecompilations(Graph* graph) {
return Status::OK();
}
-
+} // namespace reduce_recompilation
} // namespace
Status PartiallyDeclusterPass::Run(
@@ -334,8 +338,9 @@ Status PartiallyDeclusterPass::Run(
Graph* graph = options.graph->get();
- TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph));
- TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph));
+ TF_RETURN_IF_ERROR(
+ reduce_device_to_host_copies::PartiallyDeclusterGraph(graph));
+ TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(graph));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index 1fc5da5071f7aa6f6dd6636aacd60e33c12431a6..38a54cc5efae35ad77b6dc8039c653e920cfc071 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -386,7 +386,7 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
TF_ASSERT_OK(s.ToGraph(graph.get()));
// This is needed to register the XLA_GPU device.
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_ASSERT_OK(DeviceFactory::AddDevices(
SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
@@ -400,10 +400,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
TF_ASSERT_OK(PartiallyDecluster(&graph));
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
-
- for (Device* d : devices) {
- delete d;
- }
}
TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {
diff --git a/tensorflow/compiler/jit/producer_consumer_queue.h b/tensorflow/compiler/jit/producer_consumer_queue.h
deleted file mode 100644
index 7c8c04152d2f3a0fd46711df24756b7e68b967ea..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/jit/producer_consumer_queue.h
+++ /dev/null
@@ -1,132 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_
-#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_
-
-#include
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mutex.h"
-
-namespace tensorflow {
-
-// A thread-safe, first-in-first-out queue.
-template
-class ProducerConsumerQueue {
- public:
- ProducerConsumerQueue()
- : capacity_(std::numeric_limits::max()) {}
- ~ProducerConsumerQueue() = default;
-
- // Wait until the queue is non-full, then append a copy of v.
- void Put(const T &v);
-
- // Wait until the queue is non-empty, then remove and return the head value.
- T Get();
-
- // If the queue is non-empty, remove the head value, placing it in *pv, and
- // return true; otherwise return false.
- bool TryGet(T *pv);
-
- // Set the capacity of the queue; the queue is full whenever count() >=
- // capacity(). The initial value is the maximum size_t. Requires size > 0.
- void set_capacity(std::size_t size);
-
- // Return the capacity of the queue.
- std::size_t capacity() const;
-
- // Return the number of elements in the queue.
- std::size_t count() const;
-
- // Implementation details follow. Clients should ignore.
- private:
- mutable tensorflow::mutex mu_; // protects all fields below
- tensorflow::condition_variable non_empty_ GUARDED_BY(mu_);
- tensorflow::condition_variable non_full_ GUARDED_BY(mu_);
- std::size_t capacity_ GUARDED_BY(mu_);
- std::deque queue_ GUARDED_BY(mu_);
-
- TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue);
-};
-
-// ------------------------------------------------------
-// Implementation details follow. Clients should ignore.
-
-// Wait until the queue is non-full, then append a copy of v.
-template
-void ProducerConsumerQueue::Put(const T &v) {
- mutex_lock lock(mu_);
- while (queue_.size() >= capacity_) {
- non_full_.wait(lock);
- }
- queue_.push_back(v);
- non_empty_.notify_one();
-}
-
-// Wait until the queue is non-empty, then remove and return the head value.
-template
-T ProducerConsumerQueue::Get() {
- mutex_lock lock(mu_);
- while (queue_.empty()) {
- non_empty_.wait(lock);
- }
- non_full_.notify_one();
- T result_value = queue_.front();
- queue_.pop_front();
- return result_value;
-}
-
-// If the queue is non-empty, remove the head value, placing it in *pv, and
-// return true; otherwise return false.
-template
-bool ProducerConsumerQueue::TryGet(T *pv) {
- mutex_lock lock(mu_);
- bool got_element = !queue_.empty();
- if (got_element) {
- non_full_.notify_one();
- *pv = queue_.front();
- queue_.pop_front();
- }
- return got_element;
-}
-
-// Set the capacity of the queue; the queue is full whenever count() >=
-// capacity(). The initial value is the maximum size_t. Requires size > 0.
-template
-void ProducerConsumerQueue::set_capacity(std::size_t size) {
- mutex_lock lock(mu_);
- CHECK_NE(size, 0);
- capacity_ = size;
- non_full_.notify_all();
-}
-
-// Return the capacity of the queue.
-template
-std::size_t ProducerConsumerQueue::capacity() const {
- mutex_lock lock(mu_);
- std::size_t max_elements = capacity_;
- return max_elements;
-}
-
-// Return the number of elements in the queue.
-template
-std::size_t ProducerConsumerQueue::count() const {
- mutex_lock lock(mu_);
- std::size_t num_elements = queue_.size();
- return num_elements;
-}
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_
diff --git a/tensorflow/compiler/jit/producer_consumer_queue_test.cc b/tensorflow/compiler/jit/producer_consumer_queue_test.cc
deleted file mode 100644
index f61260c6e52756ee039829afdc7452f5f760c221..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/jit/producer_consumer_queue_test.cc
+++ /dev/null
@@ -1,139 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/producer_consumer_queue.h"
-
-#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace {
-
-typedef ProducerConsumerQueue IntQueue;
-
-// Insert integers between low inclusive and high exclusive into q.
-void PushRange(IntQueue *q, int low, int high) {
- while (low != high) {
- q->Put(low);
- VLOG(2) << "Pushing " << low;
- ++low;
- }
-}
-
-// Push the numbers between 0 and 999 inclusive from several threads in the
-// pool.
-void PushRanges(IntQueue *queue, thread::ThreadPool *pool) {
- VLOG(1) << "Adding 20-36";
- pool->Schedule([queue] { PushRange(queue, 20, 36); });
- VLOG(1) << "Adding 7-20";
- pool->Schedule([queue] { PushRange(queue, 7, 20); });
- VLOG(1) << "Adding 36-501";
- pool->Schedule([queue] { PushRange(queue, 36, 501); });
- VLOG(1) << "Adding 501-1000";
- pool->Schedule([queue] { PushRange(queue, 501, 1000); });
- VLOG(1) << "Adding 0-5";
- pool->Schedule([queue] { PushRange(queue, 0, 5); });
- VLOG(1) << "Adding 5-7";
- pool->Schedule([queue] { PushRange(queue, 5, 7); });
-}
-
-// Pop elements from queue using Get(). Make sure that exactly elements
-// were present and their values are all integers between 0 and high-1
-// inclusive.
-void GetRange(IntQueue *queue, int high) {
- VLOG(1) << "Testing Wait";
- std::vector results;
- for (int i = 0; i != high; ++i) {
- int r = queue->Get();
- VLOG(2) << "Waited and got " << r;
- results.push_back(r);
- }
- CHECK_EQ(queue->count(), 0);
- std::sort(results.begin(), results.end());
- for (int i = 0; i != high; ++i) {
- CHECK(results[i] == i);
- }
-}
-
-// Pop elements from queue using TryGet(). Make sure that exactly
-// elements were present and their values are all integers between 0 and high-1
-// inclusive.
-void TryGetRange(IntQueue *queue, int high) {
- std::vector results;
- // Give up if we don't get all the elements back from the queue
- // in 10 seconds.
- int timeout = 10;
- int r;
- for (int i = 0; i != high; ++i) {
- while (!queue->TryGet(&r)) {
- if (!timeout--) {
- LOG(FATAL) << "Can't find all elements in the queue";
- }
- VLOG(1) << "Sleeping for a second...";
- sleep(1);
- }
- VLOG(2) << "Popped " << r;
- results.push_back(r);
- }
- CHECK_EQ(queue->count(), 0);
- CHECK(!queue->TryGet(&r));
- std::sort(results.begin(), results.end());
- for (int i = 0; i != high; ++i) {
- CHECK_EQ(i, results[i]);
- }
-}
-
-const int kNumThreads = 15;
-
-TEST(ProducerConsumerQueue, GetRange) {
- IntQueue queue;
- {
- thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
- PushRanges(&queue, &pool);
- }
- GetRange(&queue, 1000);
-}
-
-TEST(ProducerConsumerQueue, TryGetRange) {
- IntQueue queue;
- {
- thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
- PushRanges(&queue, &pool);
- }
- TryGetRange(&queue, 1000);
-}
-
-TEST(ProducerConsumerQueue, ParallelGetRange) {
- IntQueue queue;
- {
- thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
- pool.Schedule([&queue] { GetRange(&queue, 1000); });
- PushRanges(&queue, &pool);
- }
-}
-
-TEST(ProducerConsumerQueue, ParallelTryGetRange) {
- IntQueue queue;
- {
- thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
- pool.Schedule([&queue] { TryGetRange(&queue, 1000); });
- PushRanges(&queue, &pool);
- }
-}
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 1fe612d43d10030675cf307b109e4dcc89cb2d79..c7e8d61d280a33a83c3386d8ef801018634d31ec 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -142,11 +142,22 @@ Status XlaCompileOnDemandOp::Compile(
TF_RETURN_IF_ERROR(ctx->allocate_temp(
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
Notification n;
+ Status status;
ctx->op_device_context()->CopyDeviceTensorToCPU(
&device_tensor, "ConstantArgument",
reinterpret_cast(ctx->device()), &host_tensor,
- [&](Status status) { n.Notify(); });
+ [&](Status s) {
+ status = s;
+ n.Notify();
+ });
n.WaitForNotification();
+      if (!status.ok()) {
+        // The device-to-host copy reported an error through its callback, so
+        // host_tensor cannot be trusted as a constant argument. Log the
+        // failure with enough context to identify the tensor and device, then
+        // propagate the error to the caller.
+        LOG(ERROR) << "Copying tensor of shape "
+                   << device_tensor.shape().DebugString() << " from "
+                   << ctx->device()->name() << " to CPU failed with "
+                   << status.ToString();
+        return status;
+      }
constant_arguments[i] = host_tensor;
}
}
@@ -189,6 +200,7 @@ Status XlaCompileOnDemandOp::Compile(
std::map variable_args = GetVariables(ctx);
std::vector args;
+
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arguments, variable_args, ctx, &args));
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 116e0756036e722c13f27579aa0e0876d2e846a7..e9770647e7ba96cc1db026d12d5f11f52ce98d35 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -17,8 +17,8 @@ limitations under the License.
// operators using XLA via the XLA "Host" (CPU) backend.
#include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
-#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
@@ -31,13 +31,13 @@ namespace tensorflow {
class XlaCpuDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector* devices) override;
+ std::vector>* devices) override;
};
-Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
- const string& name_prefix,
- std::vector* devices) {
- legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags();
+Status XlaCpuDeviceFactory::CreateDevices(
+ const SessionOptions& session_options, const string& name_prefix,
+ std::vector>* devices) {
+ XlaDeviceFlags* flags = GetXlaDeviceFlags();
bool compile_on_demand = flags->tf_xla_compile_on_demand;
XlaOpRegistry::DeviceRegistration registration;
@@ -64,7 +64,18 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
options.compilation_device_name = DEVICE_CPU_XLA_JIT;
options.use_multiple_streams = false;
auto device = absl::make_unique(session_options, options);
- devices->push_back(device.release());
+
+  // Setting GpuDeviceInfo because eager runtime relies on the device
+  // context in tensorflow_gpu_device_info(). Also,
+  // tensorflow_gpu_device_info() == nullptr is used as an IsCPU test.
+  // We need XlaCpuDevice to be treated not as CPU because it allocates
+  // XlaTensors, not regular Tensors.
+  Status status = device->UseGpuDeviceInfo();
+  if (!status.ok()) {
+    // Attribute the failure to the JIT device this factory is configuring
+    // (the CPU one, matching options.compilation_device_name above).
+    errors::AppendToMessage(&status, "while setting up ", DEVICE_CPU_XLA_JIT);
+    return status;
+  }
+ devices->push_back(std::move(device));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 5c1b55cb57f58387086ab9eaf924d0beffb43e18..4201ff91a89b1bee370e6a43337c51abe3bf974a 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -218,6 +218,9 @@ XlaDevice::XlaDevice(const SessionOptions& session_options,
XlaDevice::~XlaDevice() {
VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
mutex_lock lock(mu_);
+ while (outstanding_asynchronous_operations_ > 0) {
+ outstanding_asynchronous_operations_cv_.wait(lock);
+ }
if (device_context_) {
device_context_->Unref();
}
@@ -384,6 +387,7 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
Status XlaDevice::Sync() {
VLOG(1) << "XlaDevice::Sync";
+ tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true);
std::shared_ptr stream;
{
mutex_lock lock(mu_);
@@ -391,13 +395,46 @@ Status XlaDevice::Sync() {
}
if (!stream) return Status::OK();
- if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) {
+ Status status = stream->BlockHostUntilDone();
+ {
+ mutex_lock lock(mu_);
+ while (outstanding_asynchronous_operations_ > 0) {
+ outstanding_asynchronous_operations_cv_.wait(lock);
+ }
+ }
+ TF_RETURN_IF_ERROR(status);
+ if (!stream->ok()) {
return errors::Internal("XlaDevice::Sync() failed.");
}
VLOG(1) << "XlaDevice::Sync completed";
return Status::OK();
}
+// Asynchronous counterpart of Sync(). `done` is invoked exactly once: either
+// immediately with OK when the device has no stream, or from the task
+// enqueued on the stream once no asynchronous operations remain outstanding.
+// In the latter case it receives OK if the stream is healthy and an Internal
+// error otherwise.
+void XlaDevice::Sync(const DoneCallback& done) {
+  VLOG(1) << "XlaDevice::Sync (asynchronous)";
+  std::shared_ptr<se::Stream> stream;
+  {
+    // Snapshot the stream under the lock; the copy keeps it alive for the
+    // enqueued task below.
+    mutex_lock lock(mu_);
+    stream = stream_;
+  }
+  if (!stream) {
+    // No stream is set, so there is nothing to wait for.
+    done(Status::OK());
+    return;
+  }
+
+  // NOTE(review): presumably the task runs on a background thread after the
+  // work already enqueued on the stream has finished -- confirm against
+  // se::Stream::ThenEnqueueOnBackgroundThread.
+  stream->ThenEnqueueOnBackgroundThread(
+      [this, stream, done](se::StreamExecutor*) {
+        tracing::ScopedActivity activity("XlaDevice::Sync::Callback",
+                                         /*is_expensive=*/true);
+        {
+          // Wait until every AsynchronousOperationHandle has been released.
+          mutex_lock lock(mu_);
+          while (outstanding_asynchronous_operations_ > 0) {
+            outstanding_asynchronous_operations_cv_.wait(lock);
+          }
+        }
+        // Invoke the callback only after mu_ has been released: `done` is
+        // caller-supplied code, and anything in it that locks mu_ (for
+        // example releasing an AsynchronousOperationHandle) would otherwise
+        // deadlock on the non-reentrant mutex.
+        done(stream->ok() ? Status::OK()
+                          : errors::Internal("XlaDevice::Sync() failed."));
+      });
+}
+
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) {
@@ -441,6 +478,49 @@ bool XlaDevice::RequiresSyncOnCompletion() const {
return sync_on_completion_;
}
+// Binds the handle to `device` and counts it as one more outstanding
+// asynchronous operation on that device.
+XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle(
+    XlaDevice* device)
+    : device_(device) {
+  // The counter is guarded by the device mutex.
+  mutex_lock guard(device->mu_);
+  device->outstanding_asynchronous_operations_ += 1;
+}
+
+// Drops this handle's count on its device and wakes every thread waiting on
+// the outstanding-operations condition variable.
+XlaDevice::AsynchronousOperationHandle::~AsynchronousOperationHandle() {
+  // A handle without a device (e.g. one that was moved from) holds no count.
+  if (device_ == nullptr) return;
+
+  mutex_lock guard(device_->mu_);
+  device_->outstanding_asynchronous_operations_ -= 1;
+  device_->outstanding_asynchronous_operations_cv_.notify_all();
+}
+
+// Copy constructor: the copy refers to the same device as `other` and counts
+// as an additional outstanding asynchronous operation on it.
+XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle(
+    const XlaDevice::AsynchronousOperationHandle& other)
+    : device_(other.device_) {
+  // `other` may have been moved from, in which case it has no device and
+  // there is no counter to increment (and no mutex to lock).
+  if (device_) {
+    mutex_lock lock(device_->mu_);
+    ++device_->outstanding_asynchronous_operations_;
+  }
+}
+
+// Move constructor: takes over the count held by `other` without touching the
+// device counter, and leaves `other` without a device.
+XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle(
+    XlaDevice::AsynchronousOperationHandle&& other)
+    : device_(nullptr) {
+  XlaDevice* const adopted = other.device_;
+  other.device_ = nullptr;
+  device_ = adopted;
+}
+
+// Copy assignment: after the call this handle counts as one outstanding
+// operation on `other`'s device (if any), and the count it previously held on
+// its own device has been released.
+XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle::
+operator=(const XlaDevice::AsynchronousOperationHandle& other) {
+  // Self-assignment must not change the counter; incrementing without a
+  // matching decrement would leave it permanently above zero.
+  if (this == &other) return *this;
+
+  // Acquire the new count first so that, when both handles refer to the same
+  // device, the counter never drops to zero in between.
+  if (other.device_) {
+    mutex_lock lock(other.device_->mu_);
+    ++other.device_->outstanding_asynchronous_operations_;
+  }
+  // Release the count this handle held before being overwritten.
+  if (device_) {
+    mutex_lock lock(device_->mu_);
+    --device_->outstanding_asynchronous_operations_;
+    device_->outstanding_asynchronous_operations_cv_.notify_all();
+  }
+  device_ = other.device_;
+  return *this;
+}
+
+// Move assignment: releases the count this handle currently holds (if any),
+// then takes over `other`'s count and leaves `other` without a device.
+XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle::
+operator=(XlaDevice::AsynchronousOperationHandle&& other) {
+  // Self-move must keep the handle intact; clearing device_ here would leave
+  // the counter incremented with nothing left to decrement it.
+  if (this == &other) return *this;
+
+  // Give back the count held before this handle is overwritten.
+  if (device_) {
+    mutex_lock lock(device_->mu_);
+    --device_->outstanding_asynchronous_operations_;
+    device_->outstanding_asynchronous_operations_cv_.notify_all();
+  }
+  device_ = other.device_;
+  other.device_ = nullptr;
+  return *this;
+}
+
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device) {
// Any op assigned to the device that isn't rewritten by the graph rewriter
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 49f53b477ef5508a23812453cb61e29a8d8b9379..c8bb276cdb9673fdcba4cc15a9f33ecd3ae96dbb 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -135,6 +135,7 @@ class XlaDevice : public LocalDevice {
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
Status Sync() override;
+ void Sync(const DoneCallback& done) override;
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override
@@ -164,7 +165,30 @@ class XlaDevice : public LocalDevice {
bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
+  // Scoped token for one in-flight asynchronous operation on an XlaDevice.
+  // Constructing a handle increments the device's
+  // outstanding_asynchronous_operations_ counter and destroying it decrements
+  // the counter again. A copy counts as an additional outstanding operation;
+  // a move hands the existing count over to the destination.
+  class AsynchronousOperationHandle {
+   public:
+    // Construction and destruction.
+    AsynchronousOperationHandle(XlaDevice* device);
+    AsynchronousOperationHandle(const AsynchronousOperationHandle& other);
+    AsynchronousOperationHandle(AsynchronousOperationHandle&& other);
+    ~AsynchronousOperationHandle();
+
+    // Assignment.
+    AsynchronousOperationHandle& operator=(
+        const AsynchronousOperationHandle& other);
+    AsynchronousOperationHandle& operator=(AsynchronousOperationHandle&& other);
+
+   private:
+    // Device whose counter this handle contributes to; null when the handle
+    // holds no count (for example after being moved from).
+    XlaDevice* device_ = nullptr;
+  };
+
+  // Returns a handle that counts as one outstanding asynchronous operation on
+  // this device for as long as it (or a handle moved from it) is alive.
+  AsynchronousOperationHandle CreateAsynchronousOperationHandle() {
+    AsynchronousOperationHandle handle(this);
+    return handle;
+  }
+
private:
+ friend class AsynchronousOperationHandle;
+
xla::LocalClient* client() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
@@ -227,6 +251,11 @@ class XlaDevice : public LocalDevice {
// True if the device requires XlaDevice::Sync to be called on completion
// regardless of status.
bool sync_on_completion_ GUARDED_BY(mu_) = false;
+
+ // Count of outstanding asynchronous operations which must be zero on Sync()
+ // completion.
+ int64 outstanding_asynchronous_operations_ GUARDED_BY(mu_) = 0;
+ condition_variable outstanding_asynchronous_operations_cv_;
};
// Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index adf0f994b84d9fbf918a5b2478aa7d106853e038..927f983ba9ef23c8509523f42366c0c89c29db9f 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -203,6 +203,8 @@ class XlaAssignVariableOp : public OpKernel {
.HostMemory("output") \
.TypeConstraint("T"), \
ArgOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name(kArgOp).Device(DEVICE).TypeConstraint("T"), ArgOp); \
\
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
.Device(DEVICE) \
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 441970169581d53e0d8683b98d26712445b170ea..0191315a66f4d331e54fadc9dc6a073a05fd67ef 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -16,7 +16,10 @@ limitations under the License.
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
+#include
#include "absl/memory/memory.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
@@ -29,12 +32,12 @@ namespace tensorflow {
class XlaGpuDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector* devices) override;
+ std::vector>* devices) override;
};
-Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
- const string& name_prefix,
- std::vector* devices) {
+Status XlaGpuDeviceFactory::CreateDevices(
+ const SessionOptions& session_options, const string& name_prefix,
+ std::vector>* devices) {
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy =
@@ -52,8 +55,35 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
return Status::OK();
}
-
- for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) {
+ string allowed_gpus =
+ session_options.config.gpu_options().visible_device_list();
+ std::set gpu_ids;
+ int num_visible_devices = platform.ValueOrDie()->VisibleDeviceCount();
+ if (allowed_gpus.empty()) {
+ for (int i = 0; i < num_visible_devices; ++i) {
+ gpu_ids.insert(i);
+ }
+ } else {
+ // For loop below is copied from gpu/gpu_device.cc. It validates
+ // the visible_device_list and populates gpu_ids set.
+ const std::vector visible_devices =
+ absl::StrSplit(allowed_gpus, ',');
+ for (const string& platform_gpu_id_str : visible_devices) {
+ int32 platform_gpu_id;
+ if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) {
+ return errors::InvalidArgument(
+ "Could not parse entry in 'visible_device_list': '",
+ platform_gpu_id_str, "'. visible_device_list = ", allowed_gpus);
+ }
+ if (platform_gpu_id < 0 || platform_gpu_id >= num_visible_devices) {
+ return errors::InvalidArgument(
+ "'visible_device_list' listed an invalid GPU id '", platform_gpu_id,
+ "' but visible device count is ", num_visible_devices);
+ }
+ gpu_ids.insert(platform_gpu_id);
+ }
+ }
+ for (int i : gpu_ids) {
XlaDevice::Options options;
options.platform = platform.ValueOrDie();
options.device_name_prefix = name_prefix;
@@ -70,7 +100,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
return status;
}
- devices->push_back(device.release());
+ devices->push_back(std::move(device));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index e828bae865d630bd40f227943cdabb2d8d95ca48..4007309ed1c57b663dca5bac0df11260bf1327f3 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -33,12 +33,12 @@ constexpr std::array kExecAllTypes = {
class XlaInterpreterDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector* devices) override;
+ std::vector>* devices) override;
};
Status XlaInterpreterDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
- std::vector* devices) {
+ std::vector>* devices) {
static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(
DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT);
(void)registrations;
@@ -61,8 +61,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
options.device_ordinal = 0;
options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
options.use_multiple_streams = false;
- auto device = absl::make_unique(session_options, options);
- devices->push_back(device.release());
+ devices->push_back(absl::make_unique(session_options, options));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 437db019a0eabe66417725148d8b121842e90479..554227f09de0ab4d9e07f199b957657f3121ff06 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -199,19 +199,17 @@ class XlaTensorBuffer : public TensorBuffer {
public:
XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size,
Allocator* allocator)
- : expected_size_(expected_size),
+ : TensorBuffer(const_cast(ptr)),
+ expected_size_(expected_size),
actual_size_(actual_size),
- allocator_(allocator) {
- data_ = const_cast(ptr);
- }
+ allocator_(allocator) {}
~XlaTensorBuffer() override {
- if (data_) {
- allocator_->DeallocateRaw(data_);
+ if (data()) {
+ allocator_->DeallocateRaw(data());
}
}
- void* data() const override { return data_; }
size_t size() const override { return expected_size_; }
TensorBuffer* root_buffer() override { return this; }
@@ -231,7 +229,6 @@ class XlaTensorBuffer : public TensorBuffer {
}
private:
- void* data_;
size_t expected_size_;
size_t actual_size_;
Allocator* allocator_;
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 6b8e6bba1e1bbfd773141d33721e4d7e30420a11..093b61629cd0b04d5d8488139b8d7262b739f86d 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -375,27 +375,6 @@ tf_xla_py_test(
],
)
-tf_xla_py_test(
- name = "resampler_ops_test",
- size = "small",
- srcs = ["resampler_ops_test.py"],
- disabled_backends = [
- # TODO(b/74459949) Support BatchDot in CPU backend.
- "cpu",
- "cpu_ondemand",
- ],
- # TODO(b/112295522): figure out how to make OSS build pass.
- tags = ["no_oss"],
- deps = [
- ":xla_test",
- "//tensorflow/contrib/resampler:resampler_ops",
- "//tensorflow/contrib/resampler:resampler_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:platform_test",
- ],
-)
-
tf_xla_py_test(
name = "dynamic_stitch_test",
size = "small",
@@ -429,13 +408,6 @@ tf_xla_py_test(
name = "eager_test",
size = "large",
srcs = ["eager_test.py"],
- disabled_backends = [
- # TODO(b/78199195) Support XLA CPU devices in eager runtime
- "cpu",
- "cpu_ondemand",
- # TODO(b/78468222) Enable GPU backend
- "gpu",
- ],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@@ -474,7 +446,6 @@ tf_xla_py_test(
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
- "//tensorflow/python:spectral_ops",
"//tensorflow/python/ops/signal",
],
)
diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py
index 69fb3ec2964a09508e612515b9e291fc14121d68..e9c2d363acab96c0fb968cb7f901ce105ea8703e 100644
--- a/tensorflow/compiler/tests/adagrad_da_test.py
+++ b/tensorflow/compiler/tests/adagrad_da_test.py
@@ -50,8 +50,8 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run()
- self.assertAllClose([0.0, 0.0], var0.eval())
- self.assertAllClose([0.0, 0.0], var1.eval())
+ self.assertAllClose([0.0, 0.0], self.evaluate(var0))
+ self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run a step of AdagradDA
update.run()
@@ -63,9 +63,9 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
# For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534
# similarly for others.
self.assertAllCloseAccordingToType(
- np.array([-0.904534, -1.603567]), var0.eval())
+ np.array([-0.904534, -1.603567]), self.evaluate(var0))
self.assertAllCloseAccordingToType(
- np.array([-0.094821, -0.189358]), var1.eval())
+ np.array([-0.094821, -0.189358]), self.evaluate(var1))
def testAdagradDAwithoutRegularizationBasic2(self):
for dtype in self.float_types:
@@ -87,16 +87,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run()
- self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
- self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+ self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
+ self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))
# Run a step of AdagradDA
update.run()
self.assertAllCloseAccordingToType(
- np.array([-0.904534, -1.603567]), var0.eval())
+ np.array([-0.904534, -1.603567]), self.evaluate(var0))
self.assertAllCloseAccordingToType(
- np.array([-0.094821, -0.189358]), var1.eval())
+ np.array([-0.094821, -0.189358]), self.evaluate(var1))
def testAdagradDAWithL1(self):
for dtype in self.float_types:
@@ -118,16 +118,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run()
- self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
- self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+ self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
+ self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))
# Run a step of AdagradDA
update.run()
self.assertAllCloseAccordingToType(
- np.array([-0.895489, -1.59555]), var0.eval())
+ np.array([-0.895489, -1.59555]), self.evaluate(var0))
self.assertAllCloseAccordingToType(
- np.array([-0.085339, -0.17989]), var1.eval())
+ np.array([-0.085339, -0.17989]), self.evaluate(var1))
def testAdagradDAWithL1_L2(self):
for dtype in self.float_types:
@@ -149,16 +149,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run()
- self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
- self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+ self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
+ self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))
# Run a step of AdagradDA
update.run()
self.assertAllCloseAccordingToType(
- np.array([-0.046907, -0.093659]), var0.eval())
+ np.array([-0.046907, -0.093659]), self.evaluate(var0))
self.assertAllCloseAccordingToType(
- np.array([-0.004275, -0.009023]), var1.eval())
+ np.array([-0.004275, -0.009023]), self.evaluate(var1))
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py
index ab69319c59fb07e7ce56c3c287a50a6290effdfd..e26483303c3934fd51675cb1fbc998b276caf527 100644
--- a/tensorflow/compiler/tests/adagrad_test.py
+++ b/tensorflow/compiler/tests/adagrad_test.py
@@ -42,17 +42,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 3 steps of adagrad
for _ in range(3):
ada_update.run()
# Validate updated params
self.assertAllCloseAccordingToType(
- np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(),
+ np.array([-1.6026098728179932, -0.6026098728179932]),
+ self.evaluate(var0),
float_rtol=1e-5)
self.assertAllCloseAccordingToType(
- np.array([2.715679168701172, 3.715679168701172]), var1.eval(),
+ np.array([2.715679168701172, 3.715679168701172]),
+ self.evaluate(var1),
float_rtol=1e-5)
def testTensorLearningRate(self):
@@ -68,17 +70,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 3 steps of adagrad
for _ in range(3):
ada_update.run()
# Validate updated params
self.assertAllCloseAccordingToType(
- np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(),
+ np.array([-1.6026098728179932, -0.6026098728179932]),
+ self.evaluate(var0),
float_rtol=1e-5)
self.assertAllCloseAccordingToType(
- np.array([2.715679168701172, 3.715679168701172]), var1.eval(),
+ np.array([2.715679168701172, 3.715679168701172]),
+ self.evaluate(var1),
float_rtol=1e-5)
def testSharing(self):
@@ -103,18 +107,20 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()
# Fetch params to validate initial values.
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Mix the first and the second adagrad for 3 steps.
ada_update1.run()
ada_update2.run()
ada_update1.run()
# Validate updated params (the same as with only 1 Adagrad).
self.assertAllCloseAccordingToType(
- np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(),
+ np.array([-1.6026098728179932, -0.6026098728179932]),
+ self.evaluate(var0),
float_rtol=1e-5)
self.assertAllCloseAccordingToType(
- np.array([2.715679168701172, 3.715679168701172]), var1.eval(),
+ np.array([2.715679168701172, 3.715679168701172]),
+ self.evaluate(var1),
float_rtol=1e-5)
diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py
index 058576b3d4b695209952158769162bb24e7ccfce..8bcff9d379d34f8a6bb8b0fdc60b7588c6d80be9 100644
--- a/tensorflow/compiler/tests/adam_test.py
+++ b/tensorflow/compiler/tests/adam_test.py
@@ -75,23 +75,24 @@ class AdamOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
beta1_power, beta2_power = opt._get_beta_accumulators()
# Run 3 steps of Adam
for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
- self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**t,
+ self.evaluate(beta2_power))
update.run(feed_dict={grads0: grads0_np, grads1: grads1_np})
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params
- self.assertAllCloseAccordingToType(var0_np, var0.eval())
- self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testTensorLearningRate(self):
for dtype in self.float_types:
@@ -117,23 +118,24 @@ class AdamOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
beta1_power, beta2_power = opt._get_beta_accumulators()
# Run 3 steps of Adam
for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
- self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**t,
+ self.evaluate(beta2_power))
update.run(feed_dict={grads0: grads0_np, grads1: grads1_np})
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params
- self.assertAllCloseAccordingToType(var0_np, var0.eval())
- self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testSharing(self):
for dtype in self.float_types:
@@ -162,13 +164,14 @@ class AdamOptimizerTest(xla_test.XLATestCase):
beta1_power, beta2_power = opt._get_beta_accumulators()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 3 steps of intertwined Adam1 and Adam2.
for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
- self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**t,
+ self.evaluate(beta2_power))
if t % 2 == 0:
update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np})
else:
@@ -178,8 +181,8 @@ class AdamOptimizerTest(xla_test.XLATestCase):
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params
- self.assertAllCloseAccordingToType(var0_np, var0.eval())
- self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py
index 3ed1d41b7121f44dd7470f61180f7a7055369174..961b46375c941bdc3922e460a2f58345086dbceb 100644
--- a/tensorflow/compiler/tests/adamax_test.py
+++ b/tensorflow/compiler/tests/adamax_test.py
@@ -78,8 +78,8 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
beta1_power = opt._get_beta_accumulators()
@@ -87,14 +87,17 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
for t in range(1, 4):
update.run()
- self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta1_power))
var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params
- self.assertAllCloseAccordingToType(var0_np, var0.eval(), rtol=1e-2)
- self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2)
+ self.assertAllCloseAccordingToType(
+ var0_np, self.evaluate(var0), rtol=1e-2)
+ self.assertAllCloseAccordingToType(
+ var1_np, self.evaluate(var1), rtol=1e-2)
self.assertEqual("var0_%d/AdaMax:0" % (i,),
opt.get_slot(var=var0, name="m").name)
@@ -118,22 +121,23 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
beta1_power = opt._get_beta_accumulators()
# Run 3 steps of AdaMax
for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
update.run()
var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params
- self.assertAllCloseAccordingToType(var0_np, var0.eval())
- self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py
index 1bc07ace23ccdc83103abe71ee11b72994c75a6d..a37c97e6d374440aeb860b9d02f2d5dd95c91f62 100644
--- a/tensorflow/compiler/tests/addsign_test.py
+++ b/tensorflow/compiler/tests/addsign_test.py
@@ -90,8 +90,8 @@ class AddSignTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 7 steps of AddSign
# first 4 steps with positive gradient
@@ -125,8 +125,8 @@ class AddSignTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
- var0_np, var0.eval(), half_rtol=1e-2)
- self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ var0_np, self.evaluate(var0), half_rtol=1e-2)
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testDense(self):
decay_steps = 10
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 332381c59eed06d5697e58efb1d8fa2b6ef604d2..9a5423c1b2a5df7880453cbb328f6a8174066255 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -218,6 +218,21 @@ class BinaryOpsTest(xla_test.XLATestCase):
],
equality_test=self.ListsAreClose)
+ # TF doesn't define xdivy or xlogy for bfloat16, so skip that dtype.
+ if dtype != dtypes.bfloat16.as_numpy_dtype:
+ self._testBinary(
+ gen_math_ops.xdivy,
+ np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
+ np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype),
+ expected=np.array([0, 0.8, 0.5, 0.285714, 0.125, 0], dtype=dtype))
+
+ self._testBinary(
+ gen_math_ops.xlogy,
+ np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
+ np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype),
+ expected=np.array([0, 6.437752, 5.375278, 3.89182, 2.079442, 0],
+ dtype=dtype))
+
def testIntOps(self):
for dtype in self.signed_int_types:
self._testBinary(
diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py
index a57d1dc81ea2c9c188b0a3005904738aa8156bf3..5d5e486f616937601214aa169a4c329ab78932c8 100644
--- a/tensorflow/compiler/tests/categorical_op_test.py
+++ b/tensorflow/compiler/tests/categorical_op_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import googletest
@@ -56,11 +57,11 @@ class CategoricalTest(xla_test.XLATestCase):
Returns:
Frequencies from sampled classes; shape [batch_size, num_classes].
"""
- with self.cached_session() as sess, self.test_scope():
+ with self.cached_session(), self.test_scope():
random_seed.set_random_seed(1618)
op = random_ops.multinomial(logits, num_samples,
output_dtype=dtypes.int32)
- d = sess.run(op)
+ d = self.evaluate(op)
batch_size, num_classes = logits.shape
freqs_mat = []
@@ -79,15 +80,15 @@ class CategoricalTest(xla_test.XLATestCase):
def _testRngIsNotConstant(self, rng, dtype, output_dtype):
# Tests that 'rng' does not always return the same value.
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
x = rng(dtype, output_dtype)
# The random-number generator, if working correctly, should produce the
# same output multiple times with low probability.
- y = sess.run(x)
- z = sess.run(x)
- w = sess.run(x)
+ y = self.evaluate(x)
+ z = self.evaluate(x)
+ w = self.evaluate(x)
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
@@ -107,12 +108,12 @@ class CategoricalTest(xla_test.XLATestCase):
def testCategoricalIsInRange(self):
for dtype in self.float_types:
for output_dtype in self.output_dtypes():
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
x = random_ops.multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
output_dtype=output_dtype)
- y = sess.run(x)
+ y = self.evaluate(x)
self.assertTrue((y >= 0).sum() == 1000)
self.assertTrue((y < 20).sum() == 1000)
@@ -138,6 +139,57 @@ class CategoricalTest(xla_test.XLATestCase):
chi2 = self._chi2(probs, freqs)
self.assertLess(chi2, 1e-3)
+ def testStatelessMultinomialIsInRange(self):
+ for dtype in self.float_types:
+ for output_dtype in self.output_dtypes():
+ with self.cached_session() as sess:
+ with self.test_scope():
+ seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+ x = stateless_random_ops.stateless_multinomial(
+ array_ops.ones(shape=[1, 20], dtype=dtype),
+ 1000,
+ seed_t,
+ output_dtype=output_dtype)
+ y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
+ self.assertTrue((y >= 0).sum() == 1000)
+ self.assertTrue((y < 20).sum() == 1000)
+
+ def testDeterminismMultinomial(self):
+ # Outputs should match iff seeds match (collisions are merely unlikely).
+ num_samples = 10
+ with self.cached_session(), self.test_scope():
+ seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+ seeds = [(x, y) for x in range(5) for y in range(5)] * 3
+ for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
+ [0.25, 0.75]]):
+ pure = stateless_random_ops.stateless_multinomial(
+ logits, num_samples, seed=seed_t)
+ values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds]
+ for s0, v0 in values:
+ for s1, v1 in values:
+ self.assertEqual(s0 == s1, np.all(v0 == v1))
+
+ def testEmpty(self):
+ with self.cached_session():
+ with self.test_scope():
+ x = random_ops.multinomial(
+ array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32)
+ y = self.evaluate(x)
+ self.assertEqual(y.shape, (42, 0))
+
+ def testEmptyStateless(self):
+ with self.cached_session() as sess:
+ with self.test_scope():
+ seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+ x = stateless_random_ops.stateless_multinomial(
+ array_ops.zeros([42, 40]),
+ 0,
+ seed=seed_t,
+ output_dtype=dtypes.int32)
+ y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
+ self.assertEqual(y.shape, (42, 0))
+
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py
index 88bd58b2da6b2892f898ad10f3467d8ce39d6388..ef2d7af69deeebd5f4c4c7225d7027f8f76bf861 100644
--- a/tensorflow/compiler/tests/clustering_test.py
+++ b/tensorflow/compiler/tests/clustering_test.py
@@ -43,7 +43,7 @@ class ClusteringTest(xla_test.XLATestCase):
input1 = constant_op.constant(val1, name="const1")
input2 = constant_op.constant(val2, name="const2")
output = math_ops.add(input1, input2)
- result = output.eval()
+ result = self.evaluate(output)
self.assertAllClose(result, expected, rtol=1e-3)
def testAddFromCpuMultiple(self):
@@ -57,7 +57,7 @@ class ClusteringTest(xla_test.XLATestCase):
with self.test_scope():
output = math_ops.add(input1, input2)
for _ in xrange(10):
- result = output.eval()
+ result = self.evaluate(output)
self.assertAllClose(result, expected, rtol=1e-3)
def testDeadlock(self):
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 2d225ad226cac368042b95eae8fc29e6fd8e82e0..2187f57960f80300d631bdc7eb8fe5e9c8dddeea 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -72,7 +72,7 @@ class ConcatTest(xla_test.XLATestCase):
x2 = constant_op.constant(p2)
with self.test_scope():
c = array_ops.concat([x1, x2], 0)
- result = c.eval()
+ result = self.evaluate(c)
self.assertAllEqual(result[:2, :], p1)
self.assertAllEqual(result[2:, :], p2)
@@ -150,7 +150,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, 1)
- result = concated_grad.eval()
+ result = self.evaluate(concated_grad)
self.assertAllEqual(result, grad_inp)
def testGradientsSimpleAll(self):
@@ -177,7 +177,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, 0)
- result = concated_grad.eval()
+ result = self.evaluate(concated_grad)
self.assertAllEqual(result, grad_inp)
@@ -205,7 +205,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, 2)
- result = concated_grad.eval()
+ result = self.evaluate(concated_grad)
self.assertAllEqual(result, grad_inp)
@@ -242,7 +242,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, concat_dim)
- result = concated_grad.eval()
+ result = self.evaluate(concated_grad)
self.assertAllEqual(result, grad_inp)
@@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase):
def DISABLED_testZeroSize(self):
# Verify that concat doesn't crash and burn for zero size inputs
np.random.seed(7)
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
for shape0 in (), (2,):
axis = len(shape0)
@@ -270,7 +270,7 @@ class ConcatTest(xla_test.XLATestCase):
self.assertAllEqual(c.eval(), correct)
# Check gradients
dc = np.random.randn(*c.get_shape().as_list())
- dxs = sess.run(gradients_impl.gradients(c, xs, dc))
+ dxs = self.evaluate(gradients_impl.gradients(c, xs, dc))
self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
def testConcatTuple(self):
@@ -280,7 +280,7 @@ class ConcatTest(xla_test.XLATestCase):
with self.test_scope():
concat_list_t = array_ops.concat([c1, c2], 0)
concat_tuple_t = array_ops.concat((c1, c2), 0)
- self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
+ self.assertAllEqual(concat_list_t.eval(), self.evaluate(concat_tuple_t))
def testConcatNoScalars(self):
with self.cached_session():
@@ -330,47 +330,47 @@ class ConcatTest(xla_test.XLATestCase):
class ConcatOffsetTest(xla_test.XLATestCase):
def testBasic(self):
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
off = gen_array_ops.concat_offset(cdim, [s0, s1, s2])
- ans = sess.run(off)
+ ans = self.evaluate(off)
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
class PackTest(xla_test.XLATestCase):
def testBasic(self):
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
- ans = sess.run(packed)
+ ans = self.evaluate(packed)
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
def testScalars(self):
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
s0 = constant_op.constant(2, dtypes.int32)
s1 = constant_op.constant(3, dtypes.int32)
s2 = constant_op.constant(5, dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
- ans = sess.run(packed)
+ ans = self.evaluate(packed)
self.assertAllEqual(ans, [2, 3, 5])
def testEmpty(self):
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
s0 = constant_op.constant([[]], dtypes.int32)
s1 = constant_op.constant([[]], dtypes.int32)
s2 = constant_op.constant([[]], dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
- ans = sess.run(packed)
+ ans = self.evaluate(packed)
self.assertAllEqual(ans, [[[]], [[]], [[]]])
diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py
index d59fd0236f4f7da2bbfb3409342c7f70f8f5d1f6..01cc1b6392845be2418c50d55be97487eb290843 100644
--- a/tensorflow/compiler/tests/conv3d_test.py
+++ b/tensorflow/compiler/tests/conv3d_test.py
@@ -85,7 +85,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose(
x, f, y_shape, strides=strides, padding="SAME")
- value = output.eval()
+ value = self.evaluate(output)
# We count the number of cells being added at the locations in the output.
# At the center, #cells = kernel_depth * kernel_height * kernel_width
@@ -135,7 +135,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose(
x, f, y_shape, strides=strides, padding="SAME")
- value = output.eval()
+ value = self.evaluate(output)
for n in xrange(x_shape[0]):
for k in xrange(f_shape[3]):
@@ -173,7 +173,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose(
x, f, y_shape, strides=strides, padding="VALID")
- value = output.eval()
+ value = self.evaluate(output)
cache_values = np.zeros(y_shape, dtype=np.float32)
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index d1b90f098d7d6574999ba0af44b285f5ad5e4f8d..bf5ea7b1fb6fb3c774c4db20d059f131990d20d3 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -42,7 +42,7 @@ def GetRunMetadataLabels(run_metadata):
def InLabels(labels, substr):
"""Returns true iff one of the labels contains substr."""
- return any([substr in x for x in labels])
+ return any(substr in x for x in labels)
class DenseLayerTest(test.TestCase):
@@ -72,7 +72,7 @@ class DenseLayerTest(test.TestCase):
x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
y = layers.dense(x, 3)
- sess.run(variables.initialize_all_variables())
+ self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,
@@ -97,7 +97,7 @@ class DenseLayerTest(test.TestCase):
with jit_scope():
y = layers.dense(x, 3)
- sess.run(variables.initialize_all_variables())
+ self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,
@@ -126,7 +126,7 @@ class DenseLayerTest(test.TestCase):
with jit_scope():
y = layers.dense(x, 3)
- sess.run(variables.initialize_all_variables())
+ self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,
diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py
index 50b04daa6b9f4159a3c4bdeecaf900a5b35a833c..e89cf975f5d889091ce92a35165aef55ee5ad4b0 100644
--- a/tensorflow/compiler/tests/dynamic_stitch_test.py
+++ b/tensorflow/compiler/tests/dynamic_stitch_test.py
@@ -58,6 +58,15 @@ class DynamicStitchTest(xla_test.XLATestCase):
[idx1, idx2], [val1, val2],
expected=np.array([[], [], [], []], np.int32))
+ def testEmptyIndex(self):
+ idx1 = np.array([], dtype=np.int32)
+ idx2 = np.array([[], []], dtype=np.int32)
+ val1 = np.ndarray(shape=(0, 9), dtype=np.int32)
+ val2 = np.ndarray(shape=(2, 0, 9), dtype=np.int32)
+ self._AssertDynamicStitchResultIs([idx1, idx2], [val1, val2],
+ expected=np.ndarray(
+ shape=(0, 9), dtype=np.int32))
+
def testSimple1D(self):
val1 = np.array([0, 4, 7], dtype=np.int32)
val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32)
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 63cee550fde9d9d4314b1541fba191df776a4da2..2af32b537ba53723370faf81aebf308a465718c7 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -101,12 +101,12 @@ class EagerTest(xla_test.XLATestCase):
self.assertAllEqual(15, product)
# Run some ops graphly
- with context.graph_mode(), self.cached_session() as sess:
+ with context.graph_mode(), self.cached_session():
with self.test_scope():
three = constant_op.constant(3)
five = constant_op.constant(5)
product = three * five
- self.assertAllEqual(15, sess.run(product))
+ self.assertAllEqual(15, self.evaluate(product))
def testDegenerateSlices(self):
with self.test_scope():
diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py
index e92afd5d6feb42ece233ee521e3a796c4bc3914a..0edd0c35aa2d417a3ed24decbaa0b5d62d35bb62 100644
--- a/tensorflow/compiler/tests/fft_test.py
+++ b/tensorflow/compiler/tests/fft_test.py
@@ -27,8 +27,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import signal
-from tensorflow.python.ops import spectral_ops
+from tensorflow.python.ops.signal import signal
from tensorflow.python.platform import googletest
BATCH_DIMS = (3, 5)
@@ -107,39 +106,39 @@ class FFTTest(xla_test.XLATestCase):
def testFFT(self):
self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft,
- spectral_ops.fft)
+ signal.fft)
def testFFT2D(self):
self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2,
- spectral_ops.fft2d)
+ signal.fft2d)
def testFFT3D(self):
self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
lambda x: np.fft.fftn(x, axes=(-3, -2, -1)),
- spectral_ops.fft3d)
+ signal.fft3d)
def testIFFT(self):
self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft,
- spectral_ops.ifft)
+ signal.ifft)
def testIFFT2D(self):
self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2,
- spectral_ops.ifft2d)
+ signal.ifft2d)
def testIFFT3D(self):
self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)),
- spectral_ops.ifft3d)
+ signal.ifft3d)
def testRFFT(self):
self._VerifyFftMethod(
INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]),
- lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value]))
+ lambda x: signal.rfft(x, fft_length=[x.shape[-1].value]))
def testRFFT2D(self):
def _tf_fn(x):
- return spectral_ops.rfft2d(
+ return signal.rfft2d(
x, fft_length=[x.shape[-2].value, x.shape[-1].value])
self._VerifyFftMethod(
@@ -153,16 +152,33 @@ class FFTTest(xla_test.XLATestCase):
x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]])
def _tf_fn(x):
- return spectral_ops.rfft3d(
+ return signal.rfft3d(
x,
fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value])
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
+ def testRFFT3DMismatchedSize(self):
+
+ def _to_expected(x):
+ return np.fft.rfftn(
+ x,
+ axes=(-3, -2, -1),
+ s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
+
+ def _tf_fn(x):
+ return signal.rfft3d(
+ x,
+ fft_length=[
+ x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
+ ])
+
+ self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
+
def testIRFFT(self):
def _tf_fn(x):
- return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])
+ return signal.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])
self._VerifyFftMethod(
INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]),
@@ -171,7 +187,7 @@ class FFTTest(xla_test.XLATestCase):
def testIRFFT2D(self):
def _tf_fn(x):
- return spectral_ops.irfft2d(
+ return signal.irfft2d(
x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)])
self._VerifyFftMethod(
@@ -195,7 +211,7 @@ class FFTTest(xla_test.XLATestCase):
s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)])
def _tf_fn(x):
- return spectral_ops.irfft3d(
+ return signal.irfft3d(
x,
fft_length=[
x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1)
@@ -203,6 +219,30 @@ class FFTTest(xla_test.XLATestCase):
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
+ def testIRFFT3DMismatchedSize(self):
+
+ def _to_input(x):
+ return np.fft.rfftn(
+ np.real(x),
+ axes=(-3, -2, -1),
+ s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
+
+ def _to_expected(x):
+ return np.fft.irfftn(
+ x,
+ axes=(-3, -2, -1),
+ s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
+
+ def _tf_fn(x):
+ return signal.irfft3d(
+ x,
+ fft_length=[
+ x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
+ ])
+
+ self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py
index 8c7edfd277c992c35a81dd5f261256a86352254e..91d77d2f791834346f43aecb60d116ddbf2faa6e 100644
--- a/tensorflow/compiler/tests/fifo_queue_test.py
+++ b/tensorflow/compiler/tests/fifo_queue_test.py
@@ -129,7 +129,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
enqueue_op.run()
for i in xrange(len(elems)):
- vals = dequeued_t.eval()
+ vals = self.evaluate(dequeued_t)
self.assertEqual([elems[i]], vals)
def testEnqueueAndBlockingDequeue(self):
@@ -192,9 +192,9 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertEqual([], size.get_shape())
enqueue_op.run()
- self.assertEqual(1, size.eval())
+ self.assertEqual(1, self.evaluate(size))
dequeued_t.op.run()
- self.assertEqual(0, size.eval())
+ self.assertEqual(0, self.evaluate(size))
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py
index 5b197afd655404e4e36a8b3442f8db60cb1d648d..b078053cdbd6d129645734492d34dd25d28ab3ef 100644
--- a/tensorflow/compiler/tests/ftrl_test.py
+++ b/tensorflow/compiler/tests/ftrl_test.py
@@ -50,14 +50,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([0.0, 0.0], var0.eval())
- self.assertAllClose([0.0, 0.0], var1.eval())
+ self.assertAllClose([0.0, 0.0], self.evaluate(var0))
+ self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run Ftrl for a few steps
for _ in range(steps):
ftrl_update.run()
- return var0.eval(), var1.eval()
+ return self.evaluate(var0), self.evaluate(var1)
def equivAdagradTest_AdagradPart(self, steps, dtype):
var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
@@ -65,14 +65,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
adagrad_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([0.0, 0.0], var0.eval())
- self.assertAllClose([0.0, 0.0], var1.eval())
+ self.assertAllClose([0.0, 0.0], self.evaluate(var0))
+ self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run Adagrad for a few steps
for _ in range(steps):
adagrad_update.run()
- return var0.eval(), var1.eval()
+ return self.evaluate(var0), self.evaluate(var1)
def equivGradientDescentTest_FtrlPart(self, steps, dtype):
var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
@@ -85,14 +85,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([0.0, 0.0], var0.eval())
- self.assertAllClose([0.0, 0.0], var1.eval())
+ self.assertAllClose([0.0, 0.0], self.evaluate(var0))
+ self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run Ftrl for a few steps
for _ in range(steps):
ftrl_update.run()
- return var0.eval(), var1.eval()
+ return self.evaluate(var0), self.evaluate(var1)
def equivGradientDescentTest_GradientDescentPart(self, steps, dtype):
var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
@@ -100,14 +100,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
sgd_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([0.0, 0.0], var0.eval())
- self.assertAllClose([0.0, 0.0], var1.eval())
+ self.assertAllClose([0.0, 0.0], self.evaluate(var0))
+ self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run GradientDescent for a few steps
for _ in range(steps):
sgd_update.run()
- return var0.eval(), var1.eval()
+ return self.evaluate(var0), self.evaluate(var1)
def testFtrlwithoutRegularization(self):
for dtype in self.float_types:
@@ -124,8 +124,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([0.0, 0.0], var0.eval())
- self.assertAllClose([0.0, 0.0], var1.eval())
+ self.assertAllClose([0.0, 0.0], self.evaluate(var0))
+ self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run 3 steps FTRL
for _ in range(3):
@@ -134,12 +134,12 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
np.array([-2.60260963, -4.29698515]),
- var0.eval(),
+ self.evaluate(var0),
float_rtol=1e-4,
half_rtol=1e-2)
self.assertAllCloseAccordingToType(
np.array([-0.28432083, -0.56694895]),
- var1.eval(),
+ self.evaluate(var1),
float_rtol=1e-5,
half_rtol=1e-2)
@@ -158,8 +158,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([4.0, 3.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([4.0, 3.0], self.evaluate(var1))
# Run 3 steps FTRL
for _ in range(3):
@@ -167,10 +167,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
- np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5,
+ np.array([-2.55607247, -3.98729396]),
+ self.evaluate(var0),
+ 1e-5,
+ 1e-5,
float_rtol=1e-4)
self.assertAllCloseAccordingToType(
- np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5)
+ np.array([-0.28232238, -0.56096673]), self.evaluate(var1), 1e-5,
+ 1e-5)
def testFtrlWithL1(self):
for dtype in self.float_types:
@@ -187,8 +191,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([4.0, 3.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([4.0, 3.0], self.evaluate(var1))
# Run 10 steps FTRL
for _ in range(10):
@@ -197,12 +201,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
np.array([-7.66718769, -10.91273689]),
- var0.eval(),
+ self.evaluate(var0),
rtol=1e-4,
bfloat16_rtol=1e-1,
bfloat16_atol=1e-1)
self.assertAllCloseAccordingToType(
- np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4)
+ np.array([-0.93460727, -1.86147261]),
+ self.evaluate(var1),
+ rtol=1e-4)
def testFtrlWithL1_L2(self):
for dtype in self.float_types:
@@ -219,8 +225,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([4.0, 3.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([4.0, 3.0], self.evaluate(var1))
# Run 10 steps FTRL
for _ in range(10):
@@ -228,9 +234,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
- np.array([-0.24059935, -0.46829352]), var0.eval(), rtol=1e-5)
+ np.array([-0.24059935, -0.46829352]),
+ self.evaluate(var0),
+ rtol=1e-5)
self.assertAllCloseAccordingToType(
- np.array([-0.02406147, -0.04830509]), var1.eval(), rtol=1e-5)
+ np.array([-0.02406147, -0.04830509]),
+ self.evaluate(var1),
+ rtol=1e-5)
def testFtrlWithL1_L2_L2Shrinkage(self):
"""Test the new FTRL op with support for l2 shrinkage.
@@ -254,8 +264,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
- self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+ self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
+ self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))
# Run 10 steps FTRL
for _ in range(10):
@@ -263,9 +273,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
- np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4)
+ np.array([-0.22578996, -0.44345799]),
+ self.evaluate(var0),
+ rtol=1e-4)
self.assertAllCloseAccordingToType(
- np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4)
+ np.array([-0.14378493, -0.13229476]),
+ self.evaluate(var1),
+ rtol=1e-4)
def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
"""Verifies that l2 shrinkage in FTRL does not change lr schedule."""
@@ -291,8 +305,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
update1 = opt1.apply_gradients([(grads1, var1)])
variables.global_variables_initializer().run()
- self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
- self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval())
+ self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
+ self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var1))
# Run 10 steps FTRL
for _ in range(10):
@@ -301,7 +315,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# var0 is experiencing L2 shrinkage so it should be smaller than var1
# in magnitude.
- self.assertTrue((var0.eval()**2 < var1.eval()**2).all())
+ self.assertTrue((var0.eval()**2 < self.evaluate(var1)**2).all())
accum0 = list(opt0._slots["accum"].values())[0].eval()
accum1 = list(opt1._slots["accum"].values())[0].eval()
# L2 shrinkage should not change how we update grad accumulator.
diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py
index b1891b918c6584abce9da382088ed0037f5319fb..a61827c2ae44de117abad5b7db5c6bcd78fa171e 100644
--- a/tensorflow/compiler/tests/function_test.py
+++ b/tensorflow/compiler/tests/function_test.py
@@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
- with self.cached_session() as sess:
+ with self.cached_session():
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase):
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_f = Foo(a, b)
- result = sess.run(call_f)
+ result = self.evaluate(call_f)
self.assertAllClose(result, expected, rtol=1e-3)
def testNestedFunctions(self):
@@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
- with self.cached_session() as sess:
+ with self.cached_session():
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase):
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_g = Foo(a, b)
- result = sess.run(call_g)
+ result = self.evaluate(call_g)
self.assertAllClose(result, expected, rtol=1e-3)
def testFunctionMultipleRetvals(self):
@@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = Func(aval, bval)
- with self.cached_session() as sess:
+ with self.cached_session():
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase):
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_f = Foo(a, b)
- result = sess.run(call_f)
+ result = self.evaluate(call_f)
self.assertAllClose(result, expected, rtol=1e-3)
def testCompileTimeConstantsInDefun(self):
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 6f51ae33a1b0fc8670ddf0cacb03a3b5a9176a91..dbea9849e217519874352b789588a2af62f1c826 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -75,7 +75,7 @@ def RunMetadataLabels(run_metadata):
def InLabels(labels, substr):
"""Returns true iff one of the labels contains substr."""
- return any([substr in x for x in labels])
+ return any(substr in x for x in labels)
def MetadataHasXlaRunOp(run_metadata):
diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py
index 58622114e4f552fb71db9b040a39b57d7da0037c..0210201fa71a6e790e94667073ab4dba542537a5 100644
--- a/tensorflow/compiler/tests/listdiff_op_test.py
+++ b/tensorflow/compiler/tests/listdiff_op_test.py
@@ -33,13 +33,13 @@ class ListDiffTest(xla_test.XLATestCase):
def _testListDiff(self, x, y, out, idx):
for dtype in [dtypes.int32, dtypes.int64]:
for index_dtype in [dtypes.int32, dtypes.int64]:
- with self.cached_session() as sess:
+ with self.cached_session():
x_tensor = ops.convert_to_tensor(x, dtype=dtype)
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
with self.test_scope():
out_tensor, idx_tensor = array_ops.listdiff(
x_tensor, y_tensor, out_idx=index_dtype)
- tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ tf_out, tf_idx = self.evaluate([out_tensor, idx_tensor])
self.assertAllEqual(out, tf_out)
self.assertAllEqual(idx, tf_idx)
self.assertEqual(1, out_tensor.get_shape().ndims)
diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py
index c6ad67993e8bc196a74c9a328df8c9200c92c575..5dddf6ae4e8c8a3d5e9eb7b2c62298df02a0093c 100644
--- a/tensorflow/compiler/tests/lrn_ops_test.py
+++ b/tensorflow/compiler/tests/lrn_ops_test.py
@@ -120,8 +120,8 @@ class LRNTest(xla_test.XLATestCase):
with self.test_scope():
actual = gen_nn_ops.lrn_grad(out_grads, in_image, out_image,
depth_radius, bias, alpha, beta)
- expected_val = expected.eval()
- actual_val = actual.eval()
+ expected_val = self.evaluate(expected)
+ actual_val = self.evaluate(actual)
self.assertAllClose(actual_val, expected_val, rtol=1e-3)
diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py
index 265c0b6d1412de7be3a5bf5e79129cb330ceb162..776ed899e68ddd3893b8bb30b7c8034297aa6515 100644
--- a/tensorflow/compiler/tests/lstm_test.py
+++ b/tensorflow/compiler/tests/lstm_test.py
@@ -88,8 +88,8 @@ class LSTMTest(test.TestCase):
(basename, m_prev_scalar, c_prev_scalar, pad_scalar))
# Initialize variables and run the unrolled LSTM step.
- sess.run(variables.global_variables_initializer())
- return sess.run([m, c])
+ self.evaluate(variables.global_variables_initializer())
+ return self.evaluate([m, c])
def testLSTMCell(self):
# Run with all-0 weights, no padding.
@@ -173,8 +173,8 @@ class LSTMTest(test.TestCase):
(basename, m_init_scalar, c_init_scalar, pad_scalar))
# Initialize variables and run the unrolled LSTM layer.
- sess.run(variables.global_variables_initializer())
- return sess.run(out_seq)
+ self.evaluate(variables.global_variables_initializer())
+ return self.evaluate(out_seq)
def testLSTMLayer(self):
# Run with all-0 weights, no padding.
diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py
index f77521a7c49dba39849869ddceb7c0e885147722..3416f7dbd6bdd264bf79785084f981f5b07cb8a9 100644
--- a/tensorflow/compiler/tests/momentum_test.py
+++ b/tensorflow/compiler/tests/momentum_test.py
@@ -61,37 +61,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
self.assertFalse(slot1 in variables.trainable_variables())
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Step 1: the momentum accumulators where 0. So we should see a normal
# update: v -= grad * learning_rate
mom_update.run()
# Check that the momentum accumulators have been updated.
- self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
- self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.1, 0.1]), self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(
+ np.array([0.01, 0.01]), self.evaluate(slot1))
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
- np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+ self.evaluate(var0))
self.assertAllCloseAccordingToType(
- np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+ self.evaluate(var1))
# Step 2: the momentum accumulators contain the previous update.
mom_update.run()
# Check that the momentum accumulators have been updated.
self.assertAllCloseAccordingToType(
- np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+ self.evaluate(slot0))
self.assertAllCloseAccordingToType(
- np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ self.evaluate(slot1))
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
np.array([
1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
- ]), var0.eval())
+ ]), self.evaluate(var0))
self.assertAllCloseAccordingToType(
np.array([
- 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
- (0.9 * 0.01 + 0.01) * 2.0)
- ]), var1.eval())
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+ 3.98 - ((0.9 * 0.01 + 0.01) * 2.0)
+ ]), self.evaluate(var1))
def testNesterovMomentum(self):
for dtype in self.float_types:
@@ -115,8 +121,8 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
var0_np, accum0_np, var0_np * 0.8, 0.1, 0.9)
var1_np, accum1_np = self._update_nesterov_momentum_numpy(
var1_np, accum1_np, 0.9, 0.1, 0.9)
- self.assertAllCloseAccordingToType(var0_np, var0.eval())
- self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testTensorLearningRateAndMomentum(self):
for dtype in self.float_types:
@@ -141,37 +147,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
self.assertFalse(slot1 in variables.trainable_variables())
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Step 1: the momentum accumulators where 0. So we should see a normal
# update: v -= grad * learning_rate
mom_update.run()
# Check that the momentum accumulators have been updated.
- self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
- self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.1, 0.1]), self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(
+ np.array([0.01, 0.01]), self.evaluate(slot1))
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
- np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+ self.evaluate(var0))
self.assertAllCloseAccordingToType(
- np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+ self.evaluate(var1))
# Step 2: the momentum accumulators contain the previous update.
mom_update.run()
# Check that the momentum accumulators have been updated.
self.assertAllCloseAccordingToType(
- np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+ self.evaluate(slot0))
self.assertAllCloseAccordingToType(
- np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ self.evaluate(slot1))
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
np.array([
1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
- ]), var0.eval())
+ ]), self.evaluate(var0))
self.assertAllCloseAccordingToType(
np.array([
- 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
- (0.9 * 0.01 + 0.01) * 2.0)
- ]), var1.eval())
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+ 3.98 - ((0.9 * 0.01 + 0.01) * 2.0)
+ ]), self.evaluate(var1))
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py
index 77bb839409f0c323ff6ed2c8d6bd105d3003b398..9671ae0ae973ff82d22744a1feb9b4293d94bbdd 100644
--- a/tensorflow/compiler/tests/placeholder_test.py
+++ b/tensorflow/compiler/tests/placeholder_test.py
@@ -33,7 +33,7 @@ class PlaceholderTest(xla_test.XLATestCase):
ph = array_ops.placeholder_with_default(v, shape=[])
out = ph * 2
sess.run(variables.variables_initializer([v]))
- self.assertEqual(8.0, sess.run(out))
+ self.assertEqual(8.0, self.evaluate(out))
def test_placeholder_with_default_fed(self):
with self.cached_session() as sess, self.test_scope():
diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py
index 86536da7fed0e2309beb32fee9c7c605491592ed..5b35c20027700b34500a31e174061d7087094b61 100644
--- a/tensorflow/compiler/tests/powersign_test.py
+++ b/tensorflow/compiler/tests/powersign_test.py
@@ -91,8 +91,8 @@ class PowerSignTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 7 steps of powersign
# first 4 steps with positive gradient
@@ -125,8 +125,8 @@ class PowerSignTest(xla_test.XLATestCase):
)
# Validate updated params
- self.assertAllCloseAccordingToType(var0_np, var0.eval())
- self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testDense(self):
decay_steps = 10
diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py
index c41b4171e26af4f7ad0237d7407a5b3691299595..63cc51a470164915b2614a06d18ca1850bb64a3c 100644
--- a/tensorflow/compiler/tests/proximal_adagrad_test.py
+++ b/tensorflow/compiler/tests/proximal_adagrad_test.py
@@ -45,15 +45,17 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- self.assertAllClose([0.0, 0.0], var0.eval())
- self.assertAllClose([0.0, 0.0], var1.eval())
+ self.assertAllClose([0.0, 0.0], self.evaluate(var0))
+ self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run 3 steps Proximal Adagrad.
for _ in range(3):
update.run()
- self.assertAllClose(np.array([-2.60260963, -4.29698515]), var0.eval())
- self.assertAllClose(np.array([-0.28432083, -0.56694895]), var1.eval())
+ self.assertAllClose(
+ np.array([-2.60260963, -4.29698515]), self.evaluate(var0))
+ self.assertAllClose(
+ np.array([-0.28432083, -0.56694895]), self.evaluate(var1))
opt_vars = opt.variables()
self.assertStartsWith(opt_vars[0].name, var0._shared_name)
self.assertStartsWith(opt_vars[1].name, var1._shared_name)
@@ -74,14 +76,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([4.0, 3.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([4.0, 3.0], self.evaluate(var1))
# Run 3 steps Proximal Adagrad.
for _ in range(3):
update.run()
- self.assertAllClose(np.array([-1.60261, -2.296985]), var0.eval())
- self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval())
+ self.assertAllClose(np.array([-1.60261, -2.296985]), self.evaluate(var0))
+ self.assertAllClose(np.array([3.715679, 2.433051]), self.evaluate(var1))
def testProximalAdagradWithL1(self):
with self.cached_session(), self.test_scope():
@@ -98,14 +100,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([4.0, 3.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([4.0, 3.0], self.evaluate(var1))
# Run 10 steps Proximal Adagrad
for _ in range(10):
update.run()
- self.assertAllClose(np.array([-6.663634, -9.190331]), var0.eval())
- self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval())
+ self.assertAllClose(np.array([-6.663634, -9.190331]), self.evaluate(var0))
+ self.assertAllClose(np.array([2.959304, 1.029232]), self.evaluate(var1))
def testProximalAdagradWithL1_L2(self):
with self.cached_session(), self.test_scope():
@@ -122,15 +124,15 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([4.0, 3.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([4.0, 3.0], self.evaluate(var1))
# Run 10 steps Proximal Adagrad.
for _ in range(10):
update.run()
- self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval())
- self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval())
+ self.assertAllClose(np.array([-0.0495, -0.0995]), self.evaluate(var0))
+ self.assertAllClose(np.array([-0.0045, -0.0095]), self.evaluate(var1))
def applyOptimizer(self, opt, steps=5):
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
@@ -141,14 +143,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run ProximalAdagrad for a few steps
for _ in range(steps):
update.run()
- return var0.eval(), var1.eval()
+ return self.evaluate(var0), self.evaluate(var1)
def testEquivAdagradwithoutRegularization(self):
with self.cached_session(), self.test_scope():
diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py
index 3d808e6b8a71ef9fa60b671d07bfd907e9f58efc..5aec433be765dd0a04bd7ab10d5c39a5a7f48c5c 100644
--- a/tensorflow/compiler/tests/proximal_gradient_descent_test.py
+++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py
@@ -42,15 +42,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- self.assertAllClose([0.0, 0.0], var0.eval())
- self.assertAllClose([0.0, 0.0], var1.eval())
+ self.assertAllClose([0.0, 0.0], self.evaluate(var0))
+ self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run 3 steps Proximal Gradient Descent.
for _ in range(3):
update.run()
- self.assertAllClose(np.array([-0.9, -1.8]), var0.eval())
- self.assertAllClose(np.array([-0.09, -0.18]), var1.eval())
+ self.assertAllClose(np.array([-0.9, -1.8]), self.evaluate(var0))
+ self.assertAllClose(np.array([-0.09, -0.18]), self.evaluate(var1))
def testProximalGradientDescentwithoutRegularization2(self):
with self.cached_session(), self.test_scope():
@@ -64,15 +64,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([4.0, 3.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([4.0, 3.0], self.evaluate(var1))
# Run 3 steps Proximal Gradient Descent
for _ in range(3):
update.run()
- self.assertAllClose(np.array([0.1, 0.2]), var0.eval())
- self.assertAllClose(np.array([3.91, 2.82]), var1.eval())
+ self.assertAllClose(np.array([0.1, 0.2]), self.evaluate(var0))
+ self.assertAllClose(np.array([3.91, 2.82]), self.evaluate(var1))
def testProximalGradientDescentWithL1(self):
with self.cached_session(), self.test_scope():
@@ -86,15 +86,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([4.0, 3.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([4.0, 3.0], self.evaluate(var1))
# Run 10 steps proximal gradient descent.
for _ in range(10):
update.run()
- self.assertAllClose(np.array([-1.988, -3.988001]), var0.eval())
- self.assertAllClose(np.array([3.67, 2.37]), var1.eval())
+ self.assertAllClose(np.array([-1.988, -3.988001]), self.evaluate(var0))
+ self.assertAllClose(np.array([3.67, 2.37]), self.evaluate(var1))
def testProximalGradientDescentWithL1_L2(self):
with self.cached_session(), self.test_scope():
@@ -108,15 +108,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([4.0, 3.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([4.0, 3.0], self.evaluate(var1))
# Run 10 steps Proximal Gradient Descent
for _ in range(10):
update.run()
- self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval())
- self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval())
+ self.assertAllClose(np.array([-0.0495, -0.0995]), self.evaluate(var0))
+ self.assertAllClose(np.array([-0.0045, -0.0095]), self.evaluate(var1))
def applyOptimizer(self, opt, steps=5):
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
@@ -127,14 +127,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run ProximalAdagrad for a few steps
for _ in range(steps):
update.run()
- return var0.eval(), var1.eval()
+ return self.evaluate(var0), self.evaluate(var1)
def testEquivGradientDescentwithoutRegularization(self):
with self.cached_session(), self.test_scope():
diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
index 236b1b881dcaffc1a5b0c6395f0605c1d7ef0269..b4d4193e35f9e0e3b23d0242ed076dd811f4ee2b 100644
--- a/tensorflow/compiler/tests/qr_op_test.py
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -63,7 +63,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
# Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
xx = math_ops.matmul(x, x, adjoint_a=True)
identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
- precision = self.AdjustedNorm(xx.eval() - identity.eval())
+ precision = self.AdjustedNorm(xx.eval() - self.evaluate(identity))
self.assertTrue(np.all(precision < 5.0))
def _test(self, dtype, shape, full_matrices):
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 36ef6ed5fee78bad10bb1ee0bf3eb7824d05c206..97ffad34c00b8ec16eb1ec109ba5d980e0ce673d 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -46,9 +46,9 @@ class RandomOpsTest(xla_test.XLATestCase):
# The random-number generator, if working correctly, should produce the
# same output multiple times with low probability.
- y = sess.run(x)
- z = sess.run(x)
- w = sess.run(x)
+ y = self.evaluate(x)
+ z = self.evaluate(x)
+ w = self.evaluate(x)
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
@@ -83,7 +83,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.test_scope():
x = random_ops.random_uniform(
shape=[1000], dtype=dtype, minval=-2, maxval=33)
- y = sess.run(x)
+ y = self.evaluate(x)
self.assertTrue((y >= -2).sum() == 1000)
self.assertTrue((y < 33).sum() == 1000)
@@ -102,7 +102,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.cached_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
- y = sess.run(x)
+ y = self.evaluate(x)
def normal_cdf(x):
return .5 * math.erfc(-x / math.sqrt(2))
@@ -111,7 +111,7 @@ class RandomOpsTest(xla_test.XLATestCase):
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
def probit(x, sess=sess):
- return sess.run(special_math.ndtri(x))
+ return self.evaluate(special_math.ndtri(x))
a = -2.
b = 2.
@@ -148,7 +148,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.test_scope():
x = math_ops.range(1 << 16)
shuffle = random_ops.random_shuffle(x)
- result = sess.run(shuffle)
+ result = self.evaluate(shuffle)
expected = range(1 << 16)
# Compare sets to avoid randomness behavior changes but make sure still
# have all the values.
@@ -159,7 +159,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.test_scope():
x = array_ops.diag(math_ops.range(20))
shuffle = random_ops.random_shuffle(x)
- result = sess.run(shuffle)
+ result = self.evaluate(shuffle)
expected = np.diag(range(20)).flatten()
# Compare sets to avoid randomness behavior changes but make sure still
# have all the values.
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index a6b58020126a3297944f199e99b0801387615564..d23fd125163d1afe8c7fd5e008d4b617ff4b2874 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -3382,10 +3382,10 @@ int main(int argc, char** argv) {
}
// XLA devices register kernels at construction time; create all known devices
// to make sure the kernels are registered.
- std::vector devices;
+ std::vector> devices;
TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
tensorflow::SessionOptions(), "", &devices));
- tensorflow::DeviceMgr device_mgr(devices);
+ tensorflow::DeviceMgr device_mgr(std::move(devices));
tensorflow::Device* ignored;
TF_QCHECK_OK(
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index 132c59c32c9db0c8759bdbb31f8613c3ef88b485..e8fc81bbb5472669c408b8bbdbcdfcdcf461131f 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -91,6 +91,7 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
np.array([], dtype=np.bool).reshape(0, 3),
np.array([[False, True, False], [True, True, False]]),
]
+ ONES = [np.ones([34000, 2])]
def testReduceSumF32(self, index_dtype):
self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA,
@@ -149,6 +150,11 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
self.NONEMPTY_REAL_DATA, index_dtype)
+ def testReduceMeanF16(self, index_dtype):
+ if np.float16 in self.all_types:
+ self._testReduction(math_ops.reduce_mean, np.mean, np.float16, self.ONES,
+ index_dtype)
+
def testReduceMeanC64(self, index_dtype):
self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,
self.NONEMPTY_COMPLEX_DATA, index_dtype)
diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py
index 8840a1329a907bddc6ef1cb6dd1c2a6d234def5c..dc3e90b4afa41c08d899ee195d42fb91678bad1c 100644
--- a/tensorflow/compiler/tests/rmsprop_test.py
+++ b/tensorflow/compiler/tests/rmsprop_test.py
@@ -76,7 +76,7 @@ class RmspropTest(xla_test.XLATestCase):
rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered)
rms_update = rms_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
mg0 = rms_opt.get_slot(var0, "mg")
self.assertEqual(mg0 is not None, centered)
@@ -92,12 +92,12 @@ class RmspropTest(xla_test.XLATestCase):
self.assertTrue(mom1 is not None)
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 3 steps of RMSProp
for _ in range(3):
- rms_update.run()
+ self.evaluate(rms_update)
var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
var0_np,
@@ -118,14 +118,14 @@ class RmspropTest(xla_test.XLATestCase):
# Validate updated params
if centered:
- self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
- self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
- self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
- self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
- self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
- self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
- self.assertAllCloseAccordingToType(var0_np, var0.eval())
- self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ self.assertAllCloseAccordingToType(mg0_np, self.evaluate(mg0))
+ self.assertAllCloseAccordingToType(mg1_np, self.evaluate(mg1))
+ self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0))
+ self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1))
+ self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0))
+ self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1))
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
index 897db384b7e8067b0460b5f344201f101a4d8479..17639bd8a755b9e9f5acc77979ac7a4149f112db 100644
--- a/tensorflow/compiler/tests/scan_ops_test.py
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -71,7 +71,7 @@ def handle_options(func, x, axis, exclusive, reverse):
class CumsumTest(xla_test.XLATestCase):
- valid_dtypes = [np.float32]
+ valid_dtypes = [np.float32, np.int32]
def axis_dtypes(self):
return set(self.int_types).intersection([np.int32, np.int64])
@@ -149,7 +149,7 @@ class CumsumTest(xla_test.XLATestCase):
class CumprodTest(xla_test.XLATestCase):
- valid_dtypes = [np.float32]
+ valid_dtypes = [np.float32, np.int32]
def axis_dtypes(self):
return set(self.int_types).intersection([np.int32, np.int64])
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 21708aa15877647e2a979a5a2674dfb734700df3..ee7ca7e6f196e114ff18e2597145e5c198980b08 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -156,7 +156,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
def probit(x, sess=sess):
- return sess.run(special_math.ndtri(x))
+ return self.evaluate(special_math.ndtri(x))
a = -2.
b = 2.
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index 46ca371c8abf1cb4710717a183ee12820c4c4ca0..d7e26d79c4c054860ade5c8960a3bca984e020b0 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -79,7 +79,8 @@ class TensorArrayTest(xla_test.XLATestCase):
c0 = w2.stack()
self.assertAllEqual(
- convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval())
+ convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]),
+ self.evaluate(c0))
def testTensorArrayWritePack(self):
for dtype in self.numeric_tf_types:
@@ -97,7 +98,7 @@ class TensorArrayTest(xla_test.XLATestCase):
c0 = w2.stack()
- self.assertAllEqual([3, 0, 1], c0.eval().shape)
+ self.assertAllEqual([3, 0, 1], self.evaluate(c0).shape)
def _testTensorArrayWriteConcat(self, tf_dtype):
with self.cached_session(), self.test_scope():
@@ -113,8 +114,8 @@ class TensorArrayTest(xla_test.XLATestCase):
c0 = w2.concat()
self.assertAllEqual(
- convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0],
- [106.0, 107.0], [8.0, 9.0], [204.0, 205.0]]), c0.eval())
+ convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], [106.0, 107.0],
+ [8.0, 9.0], [204.0, 205.0]]), self.evaluate(c0))
def testTensorArrayWriteConcat(self):
for dtype in self.numeric_tf_types:
@@ -341,7 +342,7 @@ class TensorArrayTest(xla_test.XLATestCase):
r0_bad = gen_data_flow_ops.tensor_array_read_v3(
handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow)
with self.assertRaisesOpError("TensorArray dtype is "):
- r0_bad.eval()
+ self.evaluate(r0_bad)
# Test reading from a different index than the one we wrote to
w0.read(1)
@@ -422,7 +423,7 @@ class TensorArrayTest(xla_test.XLATestCase):
w2 = h2.write(0, 5.0)
r2 = w2.read(0)
r = r1 + r2
- self.assertAllClose(9.0, r.eval())
+ self.assertAllClose(9.0, self.evaluate(r))
def _testTensorArrayGradientWriteReadType(self, dtype):
with self.cached_session() as session, self.test_scope():
@@ -504,7 +505,7 @@ class TensorArrayTest(xla_test.XLATestCase):
[-0.5, 1.5], # read(0) gradient
[20.0, 30.0, 40.0, 50.0], # concat gradient
])
- grad_vals = sess.run(grad_r) # 2 + 2 entries
+ grad_vals = self.evaluate(grad_r) # 2 + 2 entries
self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])
@@ -526,7 +527,7 @@ class TensorArrayTest(xla_test.XLATestCase):
with ops.control_dependencies([r0_readtwice]):
r1_readtwice = w_readtwice.read(0)
- self.assertAllEqual([1.0, -1.0], r1_readtwice.eval())
+ self.assertAllEqual([1.0, -1.0], self.evaluate(r1_readtwice))
def _testTensorArrayGradientUnpackRead(self):
with self.cached_session() as session, self.test_scope():
@@ -592,7 +593,7 @@ class TensorArrayTest(xla_test.XLATestCase):
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
s = ta.size()
- self.assertAllEqual(3, s.eval())
+ self.assertAllEqual(3, self.evaluate(s))
def testWriteCloseTensorArray(self):
with self.cached_session(), self.test_scope():
@@ -722,7 +723,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# r = acc2.stack()
# grad = gradients_impl.gradients(r, [x])[0]
- # self.assertAllClose(31.0, grad.eval())
+ # self.assertAllClose(31.0, self.evaluate(grad))
def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
with self.cached_session() as session, self.test_scope():
@@ -912,7 +913,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertEqual(0, ta.size().eval())
ta = ta.unstack(array_ops.zeros([0, 3, 5]))
packed = ta.stack()
- self.assertAllEqual([0, 3, 5], packed.eval().shape)
+ self.assertAllEqual([0, 3, 5], self.evaluate(packed).shape)
# Concatenating zero tensors along their first dimension gives a
# first dimension of zero
self.assertAllEqual([0, 5], ta.concat().eval().shape)
@@ -1041,8 +1042,8 @@ class TensorArrayTest(xla_test.XLATestCase):
(read0, read1, size0, size1))
# Tests that the control dependencies was added and executed.
- self.assertEqual(1, v0.eval())
- self.assertEqual(1, v1.eval())
+ self.assertEqual(1, self.evaluate(v0))
+ self.assertEqual(1, self.evaluate(v1))
# Tests correct TensorArray.
self.assertEqual(read0_v, 0)
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index d612d3b32dd6b0893508413b337ea9ad95ef6dd7..95c9e7ffd4651642781143c2c1940b0e51e1e470 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -481,6 +481,72 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
+ def quantize_and_dequantize_v2_round_half_up(x):
+ return array_ops.quantize_and_dequantize_v2(
+ x,
+ -1,
+ 1.0,
+ signed_input=True,
+ num_bits=8,
+ range_given=True,
+ round_mode="HALF_UP")
+
+ self._assertOpOutputMatchesExpected(
+ quantize_and_dequantize_v2_round_half_up,
+ np.array([-0.8, -0.5, 0, 0.3, 0.8, -2, 33], dtype=dtype),
+ expected=np.array([
+ -102.0 / 127,
+ -63.0 / 127,
+ 0,
+ 38.0 / 127,
+ 102.0 / 127,
+ -128.0 / 127,
+ 1,
+ ],
+ dtype=dtype))
+
+ def quantize_and_dequantize_v2_round_half_to_even(x):
+ return array_ops.quantize_and_dequantize_v2(
+ x,
+ -1.0,
+ 1.0,
+ signed_input=True,
+ num_bits=8,
+ range_given=True,
+ round_mode="HALF_TO_EVEN")
+
+ self._assertOpOutputMatchesExpected(
+ quantize_and_dequantize_v2_round_half_to_even,
+ np.array(
+ [
+ -0.8,
+ # The -0.5 should become -63.5 after scaling and with
+ # rounding this should become -64. But with the test
+ # unary_ops_test_cpu_ondemand, this fails as the result
+ # after scaling becomes -63.499996 and gets rounded to -63.
+ # TODO(sreenik): Someone more familiar with this test needs
+ # to take a look and resolve this. This works on all other
+ # variations of the platform like cpu, and gpu.
+ # -0.5,
+ 0,
+ 0.3,
+ 0.8,
+ -2,
+ 33
+ ],
+ dtype=dtype),
+ expected=np.array(
+ [
+ -102.0 / 127,
+ # -64.0 / 127,
+ 0,
+ 38.0 / 127,
+ 102.0 / 127,
+ -128.0 / 127,
+ 1,
+ ],
+ dtype=dtype))
+
def quantize_and_dequantize_v3(x):
return array_ops.quantize_and_dequantize_v3(
x, -127, 127, num_bits=8, signed_input=True, range_given=False)
diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py
index 77cdeac8168aa71555955b141852587d62ab59d3..fcd7ac5ba1ca5049246e93e6f5f76746fb28c6b8 100644
--- a/tensorflow/compiler/tests/variable_ops_test.py
+++ b/tensorflow/compiler/tests/variable_ops_test.py
@@ -77,7 +77,7 @@ class VariableOpsTest(xla_test.XLATestCase):
sess.run(variables.variables_initializer([v]))
x = v.sparse_read(2)
self.assertAllClose(
- np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x))
+ np.array([8j, 9, 10, 11]).astype(dtype), self.evaluate(x))
def testSparseRead1DIndices(self):
for dtype in self.numeric_types:
@@ -89,7 +89,7 @@ class VariableOpsTest(xla_test.XLATestCase):
x = v.sparse_read([2, 1])
self.assertAllClose(
np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype),
- sess.run(x))
+ self.evaluate(x))
def testSparseRead2DIndices(self):
for dtype in self.numeric_types:
@@ -102,7 +102,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertAllClose(
np.array([[[8, 9, 10, 11], [4, 5, 6, 7]],
[[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype),
- sess.run(x))
+ self.evaluate(x))
def testSparseRead2DIndices3DTensor(self):
for dtype in self.numeric_types:
@@ -115,9 +115,9 @@ class VariableOpsTest(xla_test.XLATestCase):
x = v.sparse_read([[2, 1], [3, 0]])
self.assertAllClose(
np.array(
- [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]
- ], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]
- ],).astype(dtype), sess.run(x))
+ [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]],
+ [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]
+ ],).astype(dtype), self.evaluate(x))
def testShape(self):
for dtype in self.numeric_types:
@@ -229,7 +229,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_add(
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertAllEqual(sess.run(read), [[3], [7]])
+ self.assertAllEqual(self.evaluate(read), [[3], [7]])
def testScatterSub(self):
with self.test_session() as sess, self.test_scope():
@@ -242,7 +242,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_sub(
handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertAllEqual(sess.run(read), [[4], [-1]])
+ self.assertAllEqual(self.evaluate(read), [[4], [-1]])
def testScatterMul(self):
with self.test_session() as sess, self.test_scope():
@@ -255,7 +255,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_mul(
handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(sess.run(read), [[5]])
+ self.assertEqual(self.evaluate(read), [[5]])
def testScatterDiv(self):
with self.test_session() as sess, self.test_scope():
@@ -268,7 +268,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_div(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertAllEqual(sess.run(read), [[2]])
+ self.assertAllEqual(self.evaluate(read), [[2]])
def testScatterMin(self):
with self.test_session() as sess, self.test_scope():
@@ -281,7 +281,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_min(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(sess.run(read), [[3]])
+ self.assertEqual(self.evaluate(read), [[3]])
def testScatterMax(self):
with self.test_session() as sess, self.test_scope():
@@ -294,7 +294,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_max(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(sess.run(read), [[6]])
+ self.assertEqual(self.evaluate(read), [[6]])
def testScatterUpdate(self):
with self.test_session() as sess, self.test_scope():
@@ -307,7 +307,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_update(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(sess.run(read), [[3]])
+ self.assertEqual(self.evaluate(read), [[3]])
def testScatterAddScalar(self):
with self.test_session() as sess, self.test_scope():
@@ -320,7 +320,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_add(
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(sess.run(read), [[3]])
+ self.assertEqual(self.evaluate(read), [[3]])
def testScatterSubScalar(self):
with self.test_session() as sess, self.test_scope():
@@ -333,7 +333,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_sub(
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(sess.run(read), [[-1]])
+ self.assertEqual(self.evaluate(read), [[-1]])
def testScatterMulScalar(self):
with self.test_session() as sess, self.test_scope():
@@ -346,7 +346,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_mul(
handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(sess.run(read), [[5]])
+ self.assertEqual(self.evaluate(read), [[5]])
def testScatterDivScalar(self):
with self.test_session() as sess, self.test_scope():
@@ -359,7 +359,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_div(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(sess.run(read), [[2]])
+ self.assertEqual(self.evaluate(read), [[2]])
def testScatterMinScalar(self):
with self.test_session() as sess, self.test_scope():
@@ -372,7 +372,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_min(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(sess.run(read), [[3]])
+ self.assertEqual(self.evaluate(read), [[3]])
def testScatterMaxScalar(self):
with self.test_session() as sess, self.test_scope():
@@ -385,7 +385,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_max(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(sess.run(read), [[6]])
+ self.assertEqual(self.evaluate(read), [[6]])
def testScatterNdAddOps(self):
with self.test_session() as sess, self.test_scope():
@@ -400,7 +400,7 @@ class VariableOpsTest(xla_test.XLATestCase):
sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates))
read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.float32)
- self.assertAllClose(expected, sess.run(read))
+ self.assertAllClose(expected, self.evaluate(read))
def testScatterNdUpdateAddOps(self):
with self.test_session() as sess, self.test_scope():
@@ -416,7 +416,7 @@ class VariableOpsTest(xla_test.XLATestCase):
gen_state_ops.resource_scatter_nd_update(handle, indices, updates))
read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.float32)
- self.assertAllClose(expected, sess.run(read))
+ self.assertAllClose(expected, self.evaluate(read))
class StridedSliceAssignChecker(object):
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
index 28d61fb07dcb665fa0dbe3f3e566e291e24fa662..ef55292b1be91a731ec556d7efa9cdf1a696e5cc 100644
--- a/tensorflow/compiler/tests/xla_device_test.py
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -81,7 +81,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
with self.cached_session() as sess:
with self.test_scope():
x = gen_control_flow_ops.control_trigger()
- sess.run(x)
+ self.evaluate(x)
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index e0171415492658a76b25167107e01300ee4bde88..5a0d9b9af9d55a8dee809d3cf909bce39c3b8b6c 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -9,6 +9,7 @@ package_group(
"//tensorflow/compiler/jit/...",
"//tensorflow/compiler/tests/...",
"//tensorflow/compiler/tf2xla/...",
+ "//tensorflow/contrib/compiler/...",
],
)
@@ -195,8 +196,8 @@ cc_library(
":sharding_util",
":side_effect_util",
":tf2xla_util",
+ "//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:xla_cluster_util",
- "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -204,13 +205,13 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
- "//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -221,6 +222,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
alwayslink = 1,
@@ -437,21 +439,15 @@ cc_library(
name = "dump_graph",
srcs = [
"dump_graph.cc",
- "dump_graph_flags.cc",
- "dump_graph_flags.h",
],
hdrs = [
"dump_graph.h",
],
deps = [
- "//tensorflow/compiler/xla:parse_flags_from_env",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/compiler/jit:flags",
"//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
+ "//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
- "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc
index 380c6a7e23da92d949b26876836b999bf6406c6c..64fdbbebc65bff4ed0b965fcdd534cc9696472b6 100644
--- a/tensorflow/compiler/tf2xla/dump_graph.cc
+++ b/tensorflow/compiler/tf2xla/dump_graph.cc
@@ -18,87 +18,26 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "absl/strings/str_cat.h"
-#include "tensorflow/compiler/tf2xla/dump_graph_flags.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/compiler/jit/flags.h"
+#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace dump_graph {
-namespace {
-
-struct NameCounts {
- mutex counts_mutex;
- std::unordered_map counts;
-};
-
-string MakeUniqueFilename(string name) {
- static NameCounts& instance = *new NameCounts;
-
- // Remove illegal characters from `name`.
- for (int i = 0; i < name.size(); ++i) {
- char ch = name[i];
- if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') {
- name[i] = '_';
- }
- }
-
- int count;
- {
- mutex_lock lock(instance.counts_mutex);
- count = instance.counts[name]++;
- }
-
- string filename = name;
- if (count > 0) {
- absl::StrAppend(&filename, "_", count);
- }
- absl::StrAppend(&filename, ".pbtxt");
- return filename;
-}
-
-string WriteTextProtoToUniqueFile(
- Env* env, const string& name, const char* proto_type,
- const ::tensorflow::protobuf::Message& proto) {
- const string& dirname =
- legacy_flags::GetDumpGraphFlags()->tf_dump_graph_prefix;
- Status status = env->RecursivelyCreateDir(dirname);
- if (!status.ok()) {
- LOG(WARNING) << "Failed to create " << dirname << " for dumping "
- << proto_type << ": " << status;
- return "(unavailable)";
- }
- string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name));
- status = WriteTextProto(Env::Default(), filepath, proto);
- if (!status.ok()) {
- LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath
- << " : " << status;
- return "(unavailable)";
- }
- LOG(INFO) << "Dumped " << proto_type << " to " << filepath;
- return filepath;
-}
-
-} // anonymous namespace
-
string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) {
- return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef",
- graph_def);
+ return tensorflow::DumpGraphDefToFile(
+ name, graph_def, GetDumpGraphFlags()->tf_dump_graph_prefix);
}
string DumpGraphToFile(const string& name, Graph const& graph,
const FunctionLibraryDefinition* flib_def) {
- GraphDef graph_def;
- graph.ToGraphDef(&graph_def);
- if (flib_def) {
- *graph_def.mutable_library() = flib_def->ToProto();
- }
- return DumpGraphDefToFile(name, graph_def);
+ return tensorflow::DumpGraphToFile(name, graph, flib_def,
+ GetDumpGraphFlags()->tf_dump_graph_prefix);
}
string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) {
- return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef);
+ return tensorflow::DumpFunctionDefToFile(
+ name, fdef, GetDumpGraphFlags()->tf_dump_graph_prefix);
}
} // namespace dump_graph
diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.cc b/tensorflow/compiler/tf2xla/dump_graph_flags.cc
deleted file mode 100644
index 2eb1f8cd849b67922f94cfe3f88456b0d6beeaf8..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/tf2xla/dump_graph_flags.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Legacy flags for the XLA bridge's dump_graph module.
-
-#include
-#include
-
-#include "tensorflow/compiler/tf2xla/dump_graph_flags.h"
-#include "tensorflow/compiler/xla/parse_flags_from_env.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Pointers to the parsed value of the flags and flag descriptors, initialized
-// via flags_init.
-static DumpGraphFlags* flags;
-static std::vector* flag_list;
-static std::once_flag flags_init;
-
-// Allocate *flags. Called via call_once(&flags_init,...).
-static void AllocateFlags() {
- flags = new DumpGraphFlags;
- flags->tf_dump_graph_prefix = "/tmp/";
- flag_list = new std::vector({
- Flag("tf_dump_graph_prefix", &flags->tf_dump_graph_prefix,
- "Path prefix to which graphs dumped during debugging should be "
- "written."),
- });
- xla::ParseFlagsFromEnv(*flag_list);
-}
-
-// Append to *append_to flag definitions associated with the XLA bridge's
-// dump_graph module.
-void AppendDumpGraphFlags(std::vector* append_to) {
- std::call_once(flags_init, &AllocateFlags);
- append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
-}
-
-// Return a pointer to the DumpGraphFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-DumpGraphFlags* GetDumpGraphFlags() {
- std::call_once(flags_init, &AllocateFlags);
- return flags;
-}
-
-} // namespace legacy_flags
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.h b/tensorflow/compiler/tf2xla/dump_graph_flags.h
deleted file mode 100644
index 80a3307d920f2cc3d668d507786a02e43589f86f..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/tf2xla/dump_graph_flags.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_
-#define TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_
-
-// Legacy flags for the XLA bridge's dump_graph module.
-
-#include
-
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Append to *flag_list flag definitions associated with the XLA bridge's
-// dump_graph module.
-void AppendDumpGraphFlags(std::vector* flag_list);
-
-// The values of flags associated with the XLA bridge's
-// dump_graph module.
-typedef struct {
- string tf_dump_graph_prefix; // Path prefix to which graphs dumped during
- // debugging should be written.
-} DumpGraphFlags;
-
-// Return a pointer to the DumpGraphFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-DumpGraphFlags* GetDumpGraphFlags();
-
-} // namespace legacy_flags
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 9ef9f49f422ec4dfaf538ac3c0754ba3609d3f88..3dfd3f854c8646ebbf06d3378201d22e8741b7eb 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -75,6 +75,25 @@ Status FunctionalizeControlFlow(Graph* graph,
return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
}
+Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
+ FunctionLibraryDefinition* library) {  // Overload taking no lookup library.
+ return FunctionalizeControlFlowForGraphDef(/*lookup_library=*/nullptr,
+ graph_def, library);  // Forwards to the three-argument overload.
+}
+
+Status FunctionalizeControlFlowForGraphDef(
+ const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def,
+ FunctionLibraryDefinition* library) {
+ FunctionDefLibrary function_lib = graph_def->library();  // Copy of the input's function library; swapped back in below.
+ Graph graph(OpRegistry::Global());
+
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph));  // First argument is default-constructed.
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(lookup_library, &graph, library));  // Graph-based pass does the actual rewrite.
+ graph.ToGraphDef(graph_def);  // Write the rewritten graph back into *graph_def.
+ std::swap(*graph_def->mutable_library(), function_lib);  // Reinstall the library copied at the top of the function.
+ return Status::OK();
+}
+
Status FunctionalizeControlFlowForFunction(
const string& func_name, const string& new_func_name,
const protobuf::Map& attrs,
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
index ba99205640ccdc83a3a4d50e3ec474907894a835..91d33fa405834d7f1f8f66180583580f4f2e448a 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
@@ -33,6 +33,12 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
FunctionLibraryDefinition* library);
+Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
+ FunctionLibraryDefinition* library);
+Status FunctionalizeControlFlowForGraphDef(
+ const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def,
+ FunctionLibraryDefinition* library);
+
// This pass looks at the graph and all associated FunctionDefs, and turns
// traditional control flow structure (Switch/Merge/etc.) into functional
// control flow structure (If/While).
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index c3841f996f801e855da75b23f01d41674ec51c4d..9784985af83a18619d837528f99a60b98a501ec5 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -95,77 +95,87 @@ TEST(FunctionalizeControlFlow, Conditional) {
}
FunctionLibraryDefinition library(OpRegistry::Global(), {});
+ GraphDef optimized_graph_def;
+ graph.ToGraphDef(&optimized_graph_def);
+ TF_ASSERT_OK(
+ FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
+ GraphDef converted_graph_def;
+ graph.ToGraphDef(&converted_graph_def);
+
+ for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
+ string op_name;
+ NameAttrList then_fn;
+ NameAttrList else_fn;
+ TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
+ InstantiationResultForTest else_result;
+ TF_EXPECT_OK(
+ InstantiateFunctionForTest(else_fn.name(), library, &else_result));
+
+ // Outer graph
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
+ auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
+ auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
+ auto if_op = ops::If(scope.WithOpName(op_name), less,
+ std::initializer_list{less, y, x}, {DT_INT32},
+ then_fn, else_fn);
+ auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
+ GraphDef expected;
+ TF_EXPECT_OK(scope.ToGraphDef(&expected));
+ TF_EXPECT_GRAPH_EQ(expected, graph_def);
+ }
- GraphDef graph_def;
- graph.ToGraphDef(&graph_def);
- string op_name;
- NameAttrList then_fn;
- NameAttrList else_fn;
- TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
- InstantiationResultForTest else_result;
- TF_EXPECT_OK(
- InstantiateFunctionForTest(else_fn.name(), library, &else_result));
-
- // Outer graph
- {
- Scope scope = Scope::NewRootScope().ExitOnError();
- auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
- auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
- auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
- auto if_op = ops::If(scope.WithOpName(op_name), less,
- std::initializer_list{less, y, x}, {DT_INT32},
- then_fn, else_fn);
- auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
- GraphDef expected;
- TF_EXPECT_OK(scope.ToGraphDef(&expected));
- TF_EXPECT_GRAPH_EQ(expected, graph_def);
- }
-
- // then body.
- {
- Scope scope = Scope::NewRootScope().ExitOnError();
- auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
- auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
- auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
- auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
- auto cond = ops::Const(
- scope.WithOpName("cond").WithControlDependencies(identity), 17);
- auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
- auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0);
-
- GraphDef expected;
- TF_EXPECT_OK(scope.ToGraphDef(&expected));
-
- InstantiationResultForTest result;
- TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result));
-
- EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
- EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
- TF_EXPECT_GRAPH_EQ(expected, result.gdef);
- }
+ // then body.
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
+ auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
+ auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
+ auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
+ auto cond = ops::Const(
+ scope.WithOpName("cond").WithControlDependencies(identity), 17);
+ auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
+ auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0);
+
+ GraphDef expected;
+ TF_EXPECT_OK(scope.ToGraphDef(&expected));
+
+ InstantiationResultForTest result;
+ TF_EXPECT_OK(
+ InstantiateFunctionForTest(then_fn.name(), library, &result));
+
+ EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
+ EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}),
+ result.arg_types);
+ TF_EXPECT_GRAPH_EQ(expected, result.gdef);
+ }
- // else body.
- {
- Scope scope = Scope::NewRootScope().ExitOnError();
- auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
- auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
- auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
- auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
- auto cond_1 = ops::Const(
- scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
- auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
- auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
-
- GraphDef expected;
- TF_EXPECT_OK(scope.ToGraphDef(&expected));
-
- InstantiationResultForTest result;
- TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result));
-
- EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
- EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
- TF_EXPECT_GRAPH_EQ(expected, result.gdef);
+ // else body.
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
+ auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
+ auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
+ auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
+ auto cond_1 = ops::Const(
+ scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
+ auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
+ auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
+
+ GraphDef expected;
+ TF_EXPECT_OK(scope.ToGraphDef(&expected));
+
+ InstantiationResultForTest result;
+ TF_EXPECT_OK(
+ InstantiateFunctionForTest(else_fn.name(), library, &result));
+
+ EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
+ EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}),
+ result.arg_types);
+ TF_EXPECT_GRAPH_EQ(expected, result.gdef);
+ }
}
}
@@ -239,75 +249,77 @@ TEST(FunctionalizeControlFlow, OneLoopVar) {
}
FunctionLibraryDefinition library(OpRegistry::Global(), {});
+ GraphDef optimized_graph_def;
+ graph.ToGraphDef(&optimized_graph_def);
+ TF_ASSERT_OK(
+ FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
+ GraphDef converted_graph_def;
+ graph.ToGraphDef(&converted_graph_def);
+
+ for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
+ NameAttrList cond_fn, body_fn;
+ TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
+
+ // Outer graph
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
+ auto while_op =
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list{source}, cond_fn, body_fn);
+ auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
+ GraphDef expected;
+ TF_EXPECT_OK(scope.ToGraphDef(&expected));
+ TF_EXPECT_GRAPH_EQ(expected, graph_def);
+ }
- GraphDef graph_def;
- graph.ToGraphDef(&graph_def);
-
- NameAttrList cond_fn, body_fn;
- TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
-
- // Outer graph
- {
- Scope scope = Scope::NewRootScope().ExitOnError();
- auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
- auto while_op =
- ops::While(scope.WithOpName("while/LoopCond"),
- std::initializer_list{source}, cond_fn, body_fn);
- auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
- GraphDef expected;
- TF_EXPECT_OK(scope.ToGraphDef(&expected));
- TF_EXPECT_GRAPH_EQ(expected, graph_def);
- }
-
- // Condition graph
- {
- Scope scope = Scope::NewRootScope().ExitOnError();
- auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
- auto ten = ops::Const(
- scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
- auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
- auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
-
- GraphDef expected;
- TF_EXPECT_OK(scope.ToGraphDef(&expected));
-
- InstantiationResultForTest result;
- TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
-
- EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
- EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
- TF_EXPECT_GRAPH_EQ(expected, result.gdef);
- }
-
- // Body graph.
- {
- Scope scope = Scope::NewRootScope().ExitOnError();
- auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
- auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
- auto one = ops::Const(
- scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
- auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
- auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
-
- GraphDef expected;
- TF_EXPECT_OK(scope.ToGraphDef(&expected));
-
- InstantiationResultForTest result;
- TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
+ // Condition graph
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
+ auto ten = ops::Const(
+ scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
+ auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
+ auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
+
+ GraphDef expected;
+ TF_EXPECT_OK(scope.ToGraphDef(&expected));
+
+ InstantiationResultForTest result;
+ TF_EXPECT_OK(
+ InstantiateFunctionForTest(cond_fn.name(), library, &result));
+
+ EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
+ EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
+ TF_EXPECT_GRAPH_EQ(expected, result.gdef);
+ }
- EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
- EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
- TF_EXPECT_GRAPH_EQ(expected, result.gdef);
+ // Body graph.
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
+ auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
+ auto one = ops::Const