diff --git a/.github/ISSUE_TEMPLATE/00-bug-performance-issue.md b/.github/ISSUE_TEMPLATE/00-bug-performance-issue.md
new file mode 100644
index 0000000000000000000000000000000000000000..34ba4cf96017bb0dc15e74eee5d6ce211cf1058d
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/00-bug-performance-issue.md
@@ -0,0 +1,34 @@
+---
+name: Bug/Performance Issue
+about: Use this template for reporting a bug or a performance issue.
+
+---
+
+Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template
+
+**System information**
+- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
+- TensorFlow installed from (source or binary):
+- TensorFlow version (use command below):
+- Python version:
+- Bazel version (if compiling from source):
+- GCC/Compiler version (if compiling from source):
+- CUDA/cuDNN version:
+- GPU model and memory:
+
+
+You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
+You can also obtain the TensorFlow version with
+python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
+
+**Describe the current behavior**
+
+**Describe the expected behavior**
+
+**Code to reproduce the issue**
+Provide a reproducible test case that is the bare minimum necessary to generate the problem.
+
+**Other info / logs**
+Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
diff --git a/.github/ISSUE_TEMPLATE/10-build-installation-issue.md b/.github/ISSUE_TEMPLATE/10-build-installation-issue.md
new file mode 100644
index 0000000000000000000000000000000000000000..99c2fe61271fb51cce8aaf94d06d9d4a633aede4
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/10-build-installation-issue.md
@@ -0,0 +1,29 @@
+---
+name: Build/Installation Issue
+about: Use this template for build/installation issues
+
+---
+
+Please make sure that this is a build/installation issue. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:build_template
+
+**System information**
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
+- TensorFlow installed from (source or binary):
+- TensorFlow version:
+- Python version:
+- Installed using virtualenv? pip? conda?:
+- Bazel version (if compiling from source):
+- GCC/Compiler version (if compiling from source):
+- CUDA/cuDNN version:
+- GPU model and memory:
+
+
+
+**Describe the problem**
+
+**Provide the exact sequence of commands / steps that you executed before running into the problem**
+
+
+**Any other info / logs**
+Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
diff --git a/.github/ISSUE_TEMPLATE/20-documentation-issue.md b/.github/ISSUE_TEMPLATE/20-documentation-issue.md
new file mode 100644
index 0000000000000000000000000000000000000000..7123ca6d6c507315dd3470e1813ac9dd17ba8fcd
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/20-documentation-issue.md
@@ -0,0 +1,17 @@
+---
+name: Documentation Issue
+about: Use this template for documentation related issues
+
+---
+
+Please make sure that this is a documentation issue. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:doc_template
+
+
+**System information**
+- TensorFlow version:
+- Doc Link:
+
+
+**Describe the documentation issue**
+
+**We welcome contributions by users. Will you be able to update submit a PR (use the [doc style guide](https://www.tensorflow.org/community/documentation)) to fix the doc Issue?**
diff --git a/.github/ISSUE_TEMPLATE/30-feature-request.md b/.github/ISSUE_TEMPLATE/30-feature-request.md
new file mode 100644
index 0000000000000000000000000000000000000000..71df2e5e49f9e42a23a8c453da5335cfbbbb6211
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/30-feature-request.md
@@ -0,0 +1,22 @@
+---
+name: Feature Request
+about: Use this template for raising a feature request
+
+---
+
+Please make sure that this is a feature request. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:feature_template
+
+
+**System information**
+- TensorFlow version (you are using):
+- Are you willing to contribute it (Yes/No):
+
+
+
+**Describe the feature and the current behavior/state.**
+
+**Will this change the current api? How?**
+
+**Who will benefit with this feature?**
+
+**Any Other info.**
diff --git a/.github/ISSUE_TEMPLATE/40-tflite-op-request.md b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md
new file mode 100644
index 0000000000000000000000000000000000000000..7b391279e479ade4ed5327728f19be8752e11507
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md
@@ -0,0 +1,24 @@
+---
+name: TensorFlow Lite Op Request
+about: Use this template for reporting ops you are using or missing.
+
+---
+
+
+**System information**
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- TensorFlow installed from (source or binary):
+- TensorFlow version (or github SHA if from source):
+
+
+**Provide the text output from tflite_convert**
+
+```
+# Copy and paste here
+```
+
+Also, please include a link to a GraphDef or the model if possible.
+
+**Any other info / logs**
+
+Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
diff --git a/.github/ISSUE_TEMPLATE/50-other-issues.md b/.github/ISSUE_TEMPLATE/50-other-issues.md
new file mode 100644
index 0000000000000000000000000000000000000000..2d78d9818bb69ebc7b0807afe5297051494c991e
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/50-other-issues.md
@@ -0,0 +1,13 @@
+---
+name: Other Issues
+about: Use this template for any other non-support related issues
+
+---
+
+This template is for miscellaneous issues not covered by the other issue categories.
+
+For questions on how to work with TensorFlow, or support for problems that are not verified bugs in TensorFlow, please go to [StackOverflow](https://stackoverflow.com/questions/tagged/tensorflow).
+
+If you are reporting a vulnerability, please use the [dedicated reporting process](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md).
+
+For high-level discussions about TensorFlow, please post to discuss@tensorflow.org, for questions about the development or internal workings of TensorFlow, or if you would like to know how to contribute to TensorFlow, please post to developers@tensorflow.org.
diff --git a/.gitignore b/.gitignore
index cb65f447d4a551266e237714a16d71b58bcfc51d..90324058600bee46af56e49028977971848a80de 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
.DS_Store
.ipynb_checkpoints
node_modules
+/.bazelrc
/.tf_configure.bazelrc
/bazel-*
/bazel_pip
@@ -23,10 +24,10 @@ Pods
Podfile.lock
*.pbxproj
*.xcworkspacedata
-/tensorflow/contrib/lite/downloads/**
-/tensorflow/contrib/lite/gen/**
-/tensorflow/contrib/lite/examples/ios/simple/data/*.txt
-/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite
+/tensorflow/lite/tools/make/downloads/**
+/tensorflow/lite/gen/**
+/tensorflow/lite/examples/ios/simple/data/*.txt
+/tensorflow/lite/examples/ios/simple/data/*.tflite
xcuserdata/**
/api_init_files_list.txt
/estimator_api_init_files_list.txt
diff --git a/BUILD b/BUILD
index 4bf647e47aa56cff0b3fd5af7d5df99d8b70549b..1200cf5f7103cad12ab9693c339c372f4f3bc0fb 100644
--- a/BUILD
+++ b/BUILD
@@ -2,5 +2,7 @@ exports_files(
[
"LICENSE",
"ACKNOWLEDGEMENTS",
+ "configure",
+ "configure.py",
],
)
diff --git a/CODEOWNERS b/CODEOWNERS
index 94cc865479cd6ab5cdb589490d3a2d650f06b160..54a61a4d72c40d297d90d53e223f64f813d9167d 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,6 +1,7 @@
# Where component owners are known, add them here.
/tenosrflow/core/debug @caisq
+/tensorflow/core/nccl/ @azaks @csigg
/tensorflow/core/platform/windows/ @mrry
/tensorflow/core/platform/s3 @yongtang
/tensorflow/go @asimshankar
@@ -46,7 +47,6 @@
/tensorflow/contrib/losses/ @alextp @ispirmustafa
/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg
/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa
-/tensorflow/contrib/nccl/ @cwhipkey @zheng-xq
/tensorflow/contrib/opt/ @strategist333 @alextp
/tensorflow/contrib/pi_examples/ @maciekcc
/tensorflow/contrib/quantization/ @petewarden
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index 5fff9d05a1c589636bc9c711e6eb7cc4aba86b2f..a4647020ff76830badd75f3d3f76a41a637159bb 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -7,19 +7,22 @@ In the interest of fostering an open and welcoming environment, we as contributo
Examples of behavior that contributes to creating a positive environment include:
-* Using welcoming and inclusive language
-* Being respectful of differing viewpoints and experiences
-* Gracefully accepting constructive criticism
-* Focusing on what is best for the community
-* Showing empathy towards other community members
+* Using welcoming and inclusive language.
+* Being respectful of differing viewpoints and experiences.
+* Gracefully accepting constructive criticism.
+* Focusing on what is best for the community.
+* Showing empathy towards other community members.
Examples of unacceptable behavior by participants include:
-* The use of sexualized language or imagery and unwelcome sexual attention or advances
-* Trolling, insulting/derogatory comments, and personal or political attacks
-* Public or private harassment
-* Publishing others' private information, such as a physical or electronic address, without explicit permission
-* Conduct which could reasonably be considered inappropriate for the forum in which it occurs.
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances.
+* Trolling, insulting/derogatory comments, and personal or political attacks.
+* Public or private harassment.
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission.
+* Conduct which could reasonably be considered inappropriate for the forum in
+ which it occurs.
All TensorFlow forums and spaces are meant for professional interactions, and any behavior which could reasonably be considered inappropriate in a professional setting is unacceptable.
@@ -48,10 +51,12 @@ However, for the vast majority of issues, we aim to empower individuals to first
If you are experiencing or witnessing conflict, we ask you to use the following escalation strategy to address the conflict:
-1. Address the perceived conflict directly with those involved, preferably in a real-time medium.
-2. If this fails, get a third party (e.g. a mutual friend, and/or someone with background on the issue, but not involved in conflict) to intercede.
-3. If you are still unable to resolve the conflict, and you believe it rises to harassment or another code of conduct violation, report it.
-
+1. Address the perceived conflict directly with those involved, preferably in a
+ real-time medium.
+2. If this fails, get a third party (e.g. a mutual friend, and/or someone with
+ background on the issue, but not involved in the conflict) to intercede.
+3. If you are still unable to resolve the conflict, and you believe it rises to
+ harassment or another code of conduct violation, report it.
## Reporting Violations
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index f598999f351c10f8bd01dfbd3ad8897f19d570e8..4a296f265f7b9521c46d350cec26ff199f43eb6c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -31,8 +31,12 @@ Follow either of the two links above to access the appropriate CLA and instructi
If you have improvements to TensorFlow, send us your pull requests! For those
just getting started, Github has a [howto](https://help.github.com/articles/using-pull-requests/).
-TensorFlow team members will be assigned to review your pull requests. Once the pull requests are approved and pass continuous integration checks, we will merge the pull requests.
-For some pull requests, we will apply the patch for each pull request to our internal version control system first, and export the change out as a new commit later, at which point the original pull request will be closed. The commits in the pull request will be squashed into a single commit with the pull request creator as the author. These pull requests will be labeled as pending merge internally.
+TensorFlow team members will be assigned to review your pull requests. Once the
+pull requests are approved and pass continuous integration checks, a TensorFlow
+team member will apply `ready to pull` label to your change. This means we are
+working on getting your pull request submitted to our internal repository. After
+the change has been submitted internally, your pull request will be merged
+automatically on GitHub.
If you want to contribute but you're not sure where to start, take a look at the
[issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome).
diff --git a/ISSUES.md b/ISSUES.md
new file mode 100644
index 0000000000000000000000000000000000000000..2b330e8e0a8a3f64753cfb7a2e2362222439312d
--- /dev/null
+++ b/ISSUES.md
@@ -0,0 +1,9 @@
+If you open a GitHub Issue, here is our policy: 1. It must be a bug/performance
+issue or a feature request or a build issue or a documentation issue (for small
+doc fixes please send a PR instead). 2. Make sure the Issue Template is filled
+out. 3. The issue should be related to the repo it is created in.
+
+**Here's why we have this policy:** We want to focus on the work that benefits
+the whole community, e.g., fixing bugs and adding features. Individual support
+should be seeked on StackOverflow or other non-GitHub channels. It helps us to
+address bugs and feature requests in a timely manner.
diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md
index 52faed9297cfcaf8c93bb9c79686c9258a53c560..b3d84ad8c948df9459a8e8afb029785d6f6ad335 100644
--- a/ISSUE_TEMPLATE.md
+++ b/ISSUE_TEMPLATE.md
@@ -29,9 +29,11 @@ You can collect some of this information using our environment capture script:
https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
-You can obtain the TensorFlow version with
+You can obtain the TensorFlow version with:
+```bash
python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
+```
### Describe the problem
Describe the problem clearly here. Be sure to convey here why it's a bug in TensorFlow or a feature request.
diff --git a/README.md b/README.md
index c3455474260b2db56f1f585b70af9c259704d01a..044174947a094d43a51f7140dd40ec0f17801d40 100644
--- a/README.md
+++ b/README.md
@@ -9,12 +9,14 @@
|-----------------|
| [](https://www.tensorflow.org/api_docs/) |
-**TensorFlow** is an open source software library for numerical computation using
-data flow graphs. The graph nodes represent mathematical operations, while
+**TensorFlow** is an open source software library for numerical computation
+using data flow graphs. The graph nodes represent mathematical operations, while
the graph edges represent the multidimensional data arrays (tensors) that flow
-between them. This flexible architecture enables you to deploy computation to one
-or more CPUs or GPUs in a desktop, server, or mobile device without rewriting
-code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), a data visualization toolkit.
+between them. This flexible architecture enables you to deploy computation to
+one or more CPUs or GPUs in a desktop, server, or mobile device without
+rewriting code. TensorFlow also includes
+[TensorBoard](https://github.com/tensorflow/tensorboard), a data visualization
+toolkit.
TensorFlow was originally developed by researchers and engineers
working on the Google Brain team within Google's Machine Intelligence Research
@@ -79,9 +81,10 @@ guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's
uphold this code.**
**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for
-tracking requests and bugs. So please see
-[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions
-and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).**
+tracking requests and bugs, so please see
+[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss)
+for general questions and discussion, and please direct specific questions to
+[Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).**
The TensorFlow project strives to abide by generally accepted best practices in open-source software development:
@@ -107,25 +110,27 @@ The TensorFlow project strives to abide by generally accepted best practices in
### Community Supported Builds
-Build Type | Status | Artifacts
----------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
-**IBM s390x** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
-**IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA
-**IBM ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
-**IBM ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
-**Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
-**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)
[1.10.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)
[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl)
+Build Type | Status | Artifacts
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
+**IBM s390x** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
+**IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA
+**IBM ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
+**IBM ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
+**Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
+**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)
[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)
[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)
[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl)
## For more information
-* [TensorFlow Website](https://www.tensorflow.org)
-* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
-* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
-* [TensorFlow Twitter](https://twitter.com/tensorflow)
-* [TensorFlow Blog](https://medium.com/tensorflow)
-* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
-* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
-* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
-* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
+
+* [TensorFlow Website](https://www.tensorflow.org)
+* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
+* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
+* [TensorFlow Twitter](https://twitter.com/tensorflow)
+* [TensorFlow Blog](https://medium.com/tensorflow)
+* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
+* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
+* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
+* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
+* [TensorFlow Visualization Toolkit](https://github.com/tensorflow/tensorboard)
Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate.
diff --git a/RELEASE.md b/RELEASE.md
index 20e1d9217b7684e696d0abf427eef9ab9548d1b7..b13b071bd6cf4d3a260c8e248a67d23e1a688498 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,74 @@
+# Release 1.12.0
+
+## Major Features and Improvements
+
+* Keras models can now be directly exported to the SavedModel
+ format(`tf.contrib.saved_model.save_keras_model()`) and used with Tensorflow
+ Serving.
+* Keras models now support evaluating with a `tf.data.Dataset`.
+* TensorFlow binaries are built with XLA support linked in by default.
+
+## Bug Fixes and Other Changes
+
+* tf.data:
+ * tf.data users can now represent, get, and set options of TensorFlow
+ input pipelines using `tf.data.Options()`, `tf.data.Dataset.options()`,
+ and `tf.data.Dataset.with_options()` respectively.
+ * New `tf.data.Dataset.reduce()` API allows users to reduce a finite
+ dataset to a single element using a user-provided reduce function.
+ * New `tf.data.Dataset.window()` API allows users to create finite windows
+ of input dataset; when combined with the `tf.data.Dataset.reduce()` API,
+ this allows users to implement customized batching.
+ * All C++ code moves to the `tensorflow::data` namespace.
+ * Add support for `num_parallel_calls` to `tf.data.Dataset.interleave`.
+* `tf.contrib`:
+ * Remove `tf.contrib.linalg`. `tf.linalg` should be used instead.
+ * Replace any calls to `tf.contrib.get_signature_def_by_key(metagraph_def,
+ signature_def_key)` with
+ `meta_graph_def.signature_def[signature_def_key]`. Catching a ValueError
+ exception thrown by `tf.contrib.get_signature_def_by_key` should be
+ replaced by catching a KeyError exception.
+* `tf.contrib.data`
+ * Deprecate, and replace by tf.data.experimental.
+* Other:
+ * Instead of jemalloc, revert back to using system malloc since it
+ simplifies build and has comparable performance.
+ * Remove integer types from `tf.nn.softplus` and `tf.nn.softsign` OpDefs.
+ This is a bugfix; these ops were never meant to support integers.
+ * Allow subslicing Tensors with a single dimension.
+ * Add option to calculate string length in Unicode characters
+ * Add functionality to SubSlice a tensor.
+ * Add searchsorted (ie lower/upper_bound) op.
+ * Add model explainability to Boosted Trees.
+ * Support negative positions for tf.substr
+ * There was previously a bug in the bijector_impl where the
+ _reduce_jacobian_det_over_event does not handle scalar ILDJ
+ implementations properly.
+ * In tf eager execution, allow re-entering a GradientTape context
+ * Add tf_api_version flag. If --define=tf_api_version=2 flag is passed in,
+ then bazel will build TensorFlow API version 2.0. Note that TensorFlow
+ 2.0 is under active development and has no guarantees at this point.
+ * Add additional compression options to TfRecordWriter
+ * Performance improvements for regex full match operations.
+ * Replace tf.GraphKeys.VARIABLES with `tf.GraphKeys.GLOBAL_VARIABLES`
+ * Remove unused dynamic learning rate support.
+
+## Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+(David) Siu-Kei Muk, Ag Ramesh, Anton Dmitriev, Artem Sobolev, Avijit-Nervana,
+Bairen Yi, Bruno Goncalves, By Shen, candy.dc, Cheng Chen, Clayne Robison,
+coder3101, Dao Zhang, Elms, Fei Hu, feiquan, Geoffrey Irving, Guozhong Zhuang,
+hellcom, Hoeseong Kim, imsheridan, Jason Furmanek, Jason Zaman, Jenny Sahng,
+jiefangxuanyan, Johannes Bannhofer, Jonathan Homer, Koan-Sin Tan, kouml, Loo
+Rong Jie, Lukas Geiger, manipopopo, Ming Li, Moritz KröGer, Naurril, Niranjan
+Hasabnis, Pan Daoxin, Peng Yu, pengwa, rasmi, Roger Xin, Roland Fernandez, Sami
+Kama, Samuel Matzek, Sangjung Woo, Sergei Lebedev, Sergii Khomenko, shaohua,
+Shaohua Zhang, Shujian2015, Sunitha Kambhampati, tomguluson92, ViníCius Camargo,
+wangsiyu, weidankong, Wen-Heng (Jack) Chung, William D. Irons, Xin Jin, Yan
+Facai (颜发才), Yanbo Liang, Yash Katariya, Yong Tang, 在原佐为
+
# Release 1.11.0
## Major Features and Improvements
@@ -20,51 +91,84 @@
## Bug Fixes and Other Changes
-* C++:
- * Changed the signature of SessionFactory::NewSession so that it can return a meaningful error message on failure.
-* tf.data:
- * Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`. [tf.data] Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`.
- * `tf.data.Dataset.list_files()` raises an exception at initialization time if the argument matches no files.
- * Renamed BigTable class to BigtableTable for clarity
- * Document use of the Cloud Bigtable API
- * Adding `tf.contrib.data.reduce_dataset` which can be used to reduce a dataset to a single element.
- * Generalization of `tf.contrib.data.sliding_window_batch`.
-* INC:
- * Runtime improvements to triangular solve.
-* `tf.contrib`:
- * Add an `implementation` argument to `tf.keras.layers.LocallyConnected2D` and `tf.keras.layers.LocallyConnected1D`. The new mode (`implementation=2`) performs forward pass as a single dense matrix multiplication, allowing dramatic speedups in certain scenarios (but worse performance in others - see docstring). The option also allows to use `padding=same`.
- * Add documentation clarifying the differences between tf.fill and tf.constant.
- * Add experimental IndexedDatasets.
- * Add selective registration target using the lite proto runtime.
- * Add simple Tensor and DataType classes to TensorFlow Lite Java
- * Add support for bitcasting to/from uint32 and uint64.
- * Added a subclass of Estimator that can be created from a SavedModel (SavedModelEstimator).
- * Adds leaf index modes as an argument.
- * Allow a different output shape from the input in tf.contrib.image.transform.
- * Change the state_size order of the StackedRNNCell to be natural order. To keep the existing behavior, user can add reverse_state_order=True when constructing the StackedRNNCells.
- * Deprecate self.test_session() in favor of self.session() or self.cached_session().
- * Directly import tensor.proto.h (the transitive import will be removed from tensor.h soon)
- * Estimator.train() now supports tf.contrib.summary.\* summaries out of the box; each call to .train() will now create a separate tfevents file rather than re-using a shared one.
- * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator.
- * Fix toco compilation/execution on Windows
- * GoogleZoneProvider class added to detect which Google Cloud Engine zone tensorflow is running in.
- * It is now safe to call any of the C API's TF_Delete\* functions on nullptr
- * Log some errors on Android to logcat
- * Match FakeQuant numerics in TFLite to improve accuracy of TFLite quantized inference models.
- * Optional bucket location check for the GCS Filesystem.
- * Performance enhancements for StringSplitOp & StringSplitV2Op.
- * Performance improvements for regex replace operations.
- * TFRecordWriter now raises an error if .write() fails.
- * TPU: More helpful error messages in TPUClusterResolvers.
- * The legacy_init_op argument to SavedModelBuilder methods for adding MetaGraphs has been deprecated. Please use the equivalent main_op argument instead. As part of this, we now explicitly check for a single main_op or legacy_init_op at the time of SavedModel building, whereas the check on main_op was previously only done at load time.
- * The protocol used for Estimator training is now configurable in RunConfig.
- * Triangular solve performance improvements.
- * Unify RNN cell interface between TF and Keras. Add new get_initial_state() to Keras and TF RNN cell, which will use to replace the existing zero_state() method.
- * Update initialization of variables in Keras.
- * Updates to "constrained_optimization" in tensorflow/contrib.
- * boosted trees: adding pruning mode
- * tf.train.Checkpoint does not delete old checkpoints by default.
- * tfdbg: Limit the total disk space occupied by dumped tensor data to 100 GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow adjustment of this upper limit.
+* C++:
+ * Changed the signature of SessionFactory::NewSession so that it can
+ return a meaningful error message on failure.
+* tf.data:
+ * Remove `num_parallel_parser_calls` argument from
+ `tf.contrib.data.make_csv_dataset()`. [tf.data] Remove
+ `num_parallel_parser_calls` argument from
+ `tf.contrib.data.make_csv_dataset()`.
+ * `tf.data.Dataset.list_files()` raises an exception at initialization
+ time if the argument matches no files.
+ * Renamed BigTable class to BigtableTable for clarity
+ * Document use of the Cloud Bigtable API
+ * Add `tf.contrib.data.reduce_dataset` which can be used to reduce a
+ dataset to a single element.
+ * Generalization of `tf.contrib.data.sliding_window_batch`.
+* INC:
+ * Runtime improvements to triangular solve.
+* `tf.contrib`:
+ * Add an `implementation` argument to `tf.keras.layers.LocallyConnected2D`
+ and `tf.keras.layers.LocallyConnected1D`. The new mode
+ (`implementation=2`) performs forward pass as a single dense matrix
+ multiplication, allowing dramatic speedups in certain scenarios (but
+ worse performance in others - see docstring). The option also allows to
+ use `padding=same`.
+ * Add documentation clarifying the differences between tf.fill and
+ tf.constant.
+ * Add experimental IndexedDatasets.
+ * Add selective registration target using the lite proto runtime.
+ * Add simple Tensor and DataType classes to TensorFlow Lite Java
+ * Add support for bitcasting to/from uint32 and uint64.
+ * Added a subclass of Estimator that can be created from a SavedModel
+ (SavedModelEstimator).
+ * Adds leaf index modes as an argument.
+ * Allow a different output shape from the input in
+ tf.contrib.image.transform.
+ * Change the state_size order of the StackedRNNCell to be natural order.
+ To keep the existing behavior, user can add reverse_state_order=True
+ when constructing the StackedRNNCells.
+ * Deprecate self.test_session() in favor of self.session() or
+ self.cached_session().
+ * Directly import tensor.proto.h (the transitive import will be removed
+ from tensor.h soon)
+ * Estimator.train() now supports tf.contrib.summary.\* summaries out of
+ the box; each call to .train() will now create a separate tfevents file
+ rather than re-using a shared one.
+ * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term
+ should not end up in the accumulator.
+ * Fix toco compilation/execution on Windows
+ * GoogleZoneProvider class added to detect which Google Cloud Engine zone
+ tensorflow is running in.
+ * It is now safe to call any of the C API's TF_Delete\* functions on
+ nullptr
+ * Log some errors on Android to logcat
+ * Match FakeQuant numerics in TFLite to improve accuracy of TFLite
+ quantized inference models.
+ * Optional bucket location check for the GCS Filesystem.
+ * Performance enhancements for StringSplitOp & StringSplitV2Op.
+ * Performance improvements for regex replace operations.
+ * TFRecordWriter now raises an error if .write() fails.
+ * TPU: More helpful error messages in TPUClusterResolvers.
+ * The legacy_init_op argument to SavedModelBuilder methods for adding
+ MetaGraphs has been deprecated. Please use the equivalent main_op
+ argument instead. As part of this, we now explicitly check for a single
+ main_op or legacy_init_op at the time of SavedModel building, whereas
+ the check on main_op was previously only done at load time.
+ * The protocol used for Estimator training is now configurable in
+ RunConfig.
+ * Triangular solve performance improvements.
+ * Unify RNN cell interface between TF and Keras. Add new
+ get_initial_state() to Keras and TF RNN cell, which will use to replace
+ the existing zero_state() method.
+ * Update initialization of variables in Keras.
+ * Updates to "constrained_optimization" in tensorflow/contrib.
+ * boosted trees: adding pruning mode
+ * tf.train.Checkpoint does not delete old checkpoints by default.
+ * tfdbg: Limit the total disk space occupied by dumped tensor data to 100
+ GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow
+ adjustment of this upper limit.
## Thanks to our Contributors
@@ -154,8 +258,8 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A
* Update `tf.keras` to the Keras 2.1.6 API.
* Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082).
* Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees).
-* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/lite)
- for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/README.md)
+* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/lite)
+ for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/toco/README.md)
has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again
included in the standard `pip` installation.
* Improved data-loading and text processing with:
@@ -458,7 +562,7 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田
## Major Features And Improvements
* [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager)
preview version is now available.
-* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite)
+* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/lite)
dev preview is now available.
* CUDA 9.0 and cuDNN 7 support.
* Accelerated Linear Algebra (XLA):
@@ -805,7 +909,7 @@ See also [TensorBoard 0.1.4](https://github.com/tensorflow/tensorboard/releases/
* Adds tf.contrib.nn.rank_sampled_softmax_loss, a sampled-softmax variant that can improve rank loss.
* `tf.contrib.metrics`.{streaming_covariance,streaming_pearson_correlation} modified to return nan when they have seen less or equal to 1 unit of weight.
* Adds time series models to contrib. See contrib/timeseries/README.md for details.
-* Adds FULLY_CONNECTED Op to tensorflow/contrib/lite/schema.fbs
+* Adds FULLY_CONNECTED Op to tensorflow/lite/schema.fbs
## Known Issues
* Tensorflow_gpu compilation fails with Bazel 0.5.3.
diff --git a/WORKSPACE b/WORKSPACE
index 17961829a605c2d1f2d2ba86a7c30c47618c139b..7cc08e0164a202581ad7ebbe107a9e19410e70e4 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,5 +1,7 @@
workspace(name = "org_tensorflow")
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
http_archive(
name = "io_bazel_rules_closure",
sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae",
@@ -14,6 +16,33 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
closure_repositories()
+http_archive(
+ name = "base_images_docker",
+ sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9",
+ strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6",
+ urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"],
+)
+
+http_archive(
+ name = "bazel_toolchains",
+ sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb",
+ strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b",
+ urls = [
+ "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "io_bazel_rules_docker",
+ sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd",
+ strip_prefix = "rules_docker-0.5.1",
+ urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"],
+)
+
+load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace")
+
+remote_config_workspace()
+
# We must check the bazel version before trying to parse any other BUILD
# files, in case the parsing of those build files depends on the bazel
# version we require here.
@@ -30,9 +59,9 @@ android_workspace()
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_workspace()
-new_http_archive(
+http_archive(
name = "inception_v1",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip",
@@ -40,9 +69,9 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "mobile_ssd",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
@@ -50,9 +79,9 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "mobile_multibox",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",
@@ -60,9 +89,9 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "stylize",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
@@ -70,12 +99,13 @@ new_http_archive(
],
)
-new_http_archive(
+http_archive(
name = "speech_commands",
- build_file = "models.BUILD",
+ build_file = "//:models.BUILD",
sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
"http://download.tensorflow.org/models/speech_commands_v0.01.zip",
],
)
+
diff --git a/configure.py b/configure.py
index b564da27227ec07713f91e925ea292b35f0f02df..57a03bd17fac1a3a9942bdacf4661d021a62bbaa 100644
--- a/configure.py
+++ b/configure.py
@@ -43,7 +43,7 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
-_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16]
+_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18]
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10
@@ -238,6 +238,13 @@ def setup_python(environ_cp):
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
+ # If choosen python_lib_path is from a path specified in the PYTHONPATH
+ # variable, need to tell bazel to include PYTHONPATH
+ if environ_cp.get('PYTHONPATH'):
+ python_paths = environ_cp.get('PYTHONPATH').split(':')
+ if python_lib_path in python_paths:
+ write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH'))
+
# Write tools/python_bin_path.sh
with open(
os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
@@ -445,11 +452,12 @@ def convert_version_to_int(version):
return int(version_str)
-def check_bazel_version(min_version):
- """Check installed bazel version is at least min_version.
+def check_bazel_version(min_version, max_version):
+ """Check installed bazel version is between min_version and max_version.
Args:
min_version: string for minimum bazel version.
+ max_version: string for maximum bazel version.
Returns:
The bazel version detected.
@@ -467,6 +475,7 @@ def check_bazel_version(min_version):
min_version_int = convert_version_to_int(min_version)
curr_version_int = convert_version_to_int(curr_version)
+ max_version_int = convert_version_to_int(max_version)
# Check if current bazel version can be detected properly.
if not curr_version_int:
@@ -480,6 +489,10 @@ def check_bazel_version(min_version):
print('Please upgrade your bazel installation to version %s or higher to '
'build TensorFlow!' % min_version)
sys.exit(0)
+ if curr_version_int > max_version_int:
+ print('Please downgrade your bazel installation to version %s or lower to '
+ 'build TensorFlow!' % min_version)
+ sys.exit(0)
return curr_version
@@ -859,7 +872,7 @@ def set_tf_cuda_version(environ_cp):
cuda_toolkit_paths_full = [
os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths
]
- if any([os.path.exists(x) for x in cuda_toolkit_paths_full]):
+ if any(os.path.exists(x) for x in cuda_toolkit_paths_full):
break
# Reset and retry
@@ -1182,6 +1195,7 @@ def set_tf_nccl_install_path(environ_cp):
if is_windows() or is_cygwin():
nccl_install_path = cygpath(nccl_install_path)
+ nccl_lib_path = ''
if is_windows():
nccl_lib_path = 'lib/x64/nccl.lib'
elif is_linux():
@@ -1417,11 +1431,16 @@ def set_mpi_home(environ_cp):
def valid_mpi_path(mpi_home):
exists = (
os.path.exists(os.path.join(mpi_home, 'include')) and
- os.path.exists(os.path.join(mpi_home, 'lib')))
+ (os.path.exists(os.path.join(mpi_home, 'lib')) or
+ os.path.exists(os.path.join(mpi_home, 'lib64')) or
+ os.path.exists(os.path.join(mpi_home, 'lib32'))))
if not exists:
- print('Invalid path to the MPI Toolkit. %s or %s cannot be found' %
- (os.path.join(mpi_home, 'include'),
- os.path.exists(os.path.join(mpi_home, 'lib'))))
+ print(
+ 'Invalid path to the MPI Toolkit. %s or %s or %s or %s cannot be found'
+ % (os.path.join(mpi_home, 'include'),
+ os.path.exists(os.path.join(mpi_home, 'lib')),
+ os.path.exists(os.path.join(mpi_home, 'lib64')),
+ os.path.exists(os.path.join(mpi_home, 'lib32'))))
return exists
_ = prompt_loop_or_load_from_env(
@@ -1462,8 +1481,17 @@ def set_other_mpi_vars(environ_cp):
if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')):
symlink_force(
os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so')
+ elif os.path.exists(os.path.join(mpi_home, 'lib64/libmpi.so')):
+ symlink_force(
+ os.path.join(mpi_home, 'lib64/libmpi.so'), 'third_party/mpi/libmpi.so')
+ elif os.path.exists(os.path.join(mpi_home, 'lib32/libmpi.so')):
+ symlink_force(
+ os.path.join(mpi_home, 'lib32/libmpi.so'), 'third_party/mpi/libmpi.so')
+
else:
- raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home)
+ raise ValueError(
+ 'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' %
+ mpi_home, mpi_home, mpi_home)
def set_system_libs_flag(environ_cp):
@@ -1537,9 +1565,12 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
- check_bazel_version('0.15.0')
+ check_bazel_version('0.15.0', '0.19.2')
reset_tf_configure_bazelrc()
+ # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later
+ write_to_bazelrc('import %workspace%/tools/bazel.rc')
+
cleanup_makefile()
setup_python(environ_cp)
@@ -1667,6 +1698,8 @@ def main():
config_info_line('gdr', 'Build with GDR support.')
config_info_line('verbs', 'Build with libverbs support.')
config_info_line('ngraph', 'Build with Intel nGraph support.')
+ config_info_line('dynamic_kernels',
+ '(Experimental) Build kernels into separate shared objects.')
print('Preconfigured Bazel build configs to DISABLE default on features:')
config_info_line('noaws', 'Disable AWS S3 filesystem support.')
@@ -1674,8 +1707,8 @@ def main():
config_info_line('nohdfs', 'Disable HDFS support.')
config_info_line('noignite', 'Disable Apacha Ignite support.')
config_info_line('nokafka', 'Disable Apache Kafka support.')
+ config_info_line('nonccl', 'Disable NVIDIA NCCL support.')
if __name__ == '__main__':
main()
-
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 77e3baaff198b402dc04daa1b11e4007b9906b23..fd4b94202aad24a82abef8abd16431f61a8326f0 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -43,6 +43,11 @@ TENSORFLOW_API_INIT_FILES_V2 = (
TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
)
+# @unused
+TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT = (
+ TENSORFLOW_API_INIT_FILES_V1 + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
+)
+
# Config setting used when building for products
# which requires restricted licenses to be avoided.
config_setting(
@@ -213,31 +218,37 @@ config_setting(
#
config_setting(
name = "no_aws_support",
- define_values = {"no_aws_support": "false"},
+ define_values = {"no_aws_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "no_gcp_support",
- define_values = {"no_gcp_support": "false"},
+ define_values = {"no_gcp_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "no_hdfs_support",
- define_values = {"no_hdfs_support": "false"},
+ define_values = {"no_hdfs_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "no_ignite_support",
- define_values = {"no_ignite_support": "false"},
+ define_values = {"no_ignite_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "no_kafka_support",
- define_values = {"no_kafka_support": "false"},
+ define_values = {"no_kafka_support": "true"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "no_nccl_support",
+ define_values = {"no_nccl_support": "true"},
visibility = ["//visibility:public"],
)
@@ -350,8 +361,9 @@ package_group(
"-//third_party/tensorflow/python/estimator",
"//learning/meta_rank/...",
"//tensorflow/...",
- "//tensorflow_estimator/...",
+ "//tensorflow_estimator/contrib/...",
"//tensorflow_fold/llgtm/...",
+ "//tensorflow_text/...",
"//third_party/py/tensor2tensor/...",
],
)
@@ -553,35 +565,45 @@ genrule(
}),
outs = ["__init__.py"],
cmd = select({
- "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
- "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
+ "api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)",
+ "//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)",
}),
)
gen_api_init_files(
name = "tf_python_api_gen_v1",
- srcs = ["api_template.__init__.py"],
+ srcs = [
+ "api_template_v1.__init__.py",
+ "compat_template_v1.__init__.py",
+ ],
api_version = 1,
+ compat_api_versions = [1],
+ compat_init_templates = ["compat_template_v1.__init__.py"],
output_dir = "_api/v1/",
- output_files = TENSORFLOW_API_INIT_FILES_V1,
+ output_files = TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT,
output_package = "tensorflow._api.v1",
- root_init_template = "api_template.__init__.py",
+ root_file_name = "v1.py",
+ root_init_template = "api_template_v1.__init__.py",
)
gen_api_init_files(
name = "tf_python_api_gen_v2",
- srcs = ["api_template.__init__.py"],
+ srcs = [
+ "api_template.__init__.py",
+ "compat_template_v1.__init__.py",
+ ],
api_version = 2,
compat_api_versions = [1],
+ compat_init_templates = ["compat_template_v1.__init__.py"],
output_dir = "_api/v2/",
output_files = TENSORFLOW_API_INIT_FILES_V2,
output_package = "tensorflow._api.v2",
+ root_file_name = "v2.py",
root_init_template = "api_template.__init__.py",
)
py_library(
name = "tensorflow_py",
- srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 2de740e145f93b151faf5c987808dbdf73fb4fd7..f13623b0d57d3b59bb9455a46a9fab29fee25784 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -21,41 +21,23 @@ from __future__ import print_function as _print_function
import os as _os
# pylint: disable=g-bad-import-order
-from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
-
-try:
- # Add `estimator` attribute to allow access to estimator APIs via
- # "tf.estimator..."
- from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top
-
- # Add `estimator` to the __path__ to allow "from tensorflow.estimator..."
- # style imports.
- from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top
- __path__ += [_os.path.dirname(estimator_api.__file__)]
- del estimator_api
-except (ImportError, AttributeError):
- print('tf.estimator package not installed.')
+from tensorflow.python.tools import component_api_helper as _component_api_helper
+_component_api_helper.package_hook(
+ parent_package_str=__name__,
+ child_package_str=('tensorflow_estimator.python.estimator.api.estimator'))
# API IMPORTS PLACEHOLDER
-from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
-contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
-del LazyLoader
-# The templated code that replaces the placeholder above sometimes
-# sets the __all__ variable. If it does, we have to be sure to add
-# "contrib".
-if '__all__' in vars():
- vars()['__all__'].append('contrib')
-
-from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
-app.flags = flags # pylint: disable=undefined-variable
-
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
-_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable
+# We're using bitwise, but there's nothing special about that.
+_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable
if _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
+# Calls to enable and disable features.
+enable_eager_execution() # pylint: disable=undefined-variable
+
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
@@ -66,7 +48,14 @@ try:
del core
except NameError:
# Don't fail if these modules are not available.
- # For e.g. we are using this file for compat.v1 module as well and
- # 'python', 'core' directories are not under compat/v1.
+ # For e.g. this file will be originally placed under tensorflow/_api/v1 which
+ # does not have 'python', 'core' directories. Then, it will be copied
+ # to tensorflow/ which does have these two directories.
+ pass
+# Similarly for compiler. Do it separately to make sure we do this even if the
+# others don't exist.
+try:
+ del compiler
+except NameError:
pass
# pylint: enable=undefined-variable
diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..65bdb6cb1b5e6fb0656a12b932d767aeacfccd29
--- /dev/null
+++ b/tensorflow/api_template_v1.__init__.py
@@ -0,0 +1,72 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Bring in all of the public TensorFlow interface into this module."""
+
+from __future__ import absolute_import as _absolute_import
+from __future__ import division as _division
+from __future__ import print_function as _print_function
+
+import os as _os
+
+# pylint: disable=g-bad-import-order
+from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
+
+from tensorflow.python.tools import component_api_helper as _component_api_helper
+_component_api_helper.package_hook(
+ parent_package_str=__name__,
+ child_package_str=('tensorflow_estimator.python.estimator.api.estimator'))
+
+# API IMPORTS PLACEHOLDER
+
+from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
+contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
+del LazyLoader
+# The templated code that replaces the placeholder above sometimes
+# sets the __all__ variable. If it does, we have to be sure to add
+# "contrib".
+if '__all__' in vars():
+ vars()['__all__'].append('contrib')
+
+from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
+app.flags = flags # pylint: disable=undefined-variable
+
+# Make sure directory containing top level submodules is in
+# the __path__ so that "from tensorflow.foo import bar" works.
+_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable
+if _tf_api_dir not in __path__:
+ __path__.append(_tf_api_dir)
+
+
+# These symbols appear because we import the python package which
+# in turn imports from tensorflow.core and tensorflow.python. They
+# must come from this module. So python adds these symbols for the
+# resolution to succeed.
+# pylint: disable=undefined-variable
+try:
+ del python
+ del core
+except NameError:
+ # Don't fail if these modules are not available.
+ # For e.g. this file will be originally placed under tensorflow/_api/v1 which
+ # does not have 'python', 'core' directories. Then, it will be copied
+ # to tensorflow/ which does have these two directories.
+ pass
+# Similarly for compiler. Do it separately to make sure we do this even if the
+# others don't exist.
+try:
+ del compiler
+except NameError:
+ pass
+# pylint: enable=undefined-variable
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 17e2e292eb19029d279bc12a8328edadf96f1bb8..f653e581bf3beda9fdbf8fb7905a4f9fe170e7fb 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -6,11 +6,12 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
- "tf_cuda_cc_test",
"tf_copts",
"tf_cuda_library",
"tf_custom_op_library",
+ "tf_kernel_library",
)
+load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
# -----------------------------------------------------------------------------
# Public targets
@@ -59,6 +60,7 @@ tf_cuda_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:op_gen_lib",
+ "//tensorflow/core/distributed_runtime:server_lib",
],
}),
)
@@ -94,6 +96,7 @@ tf_cuda_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core/distributed_runtime:server_lib",
],
}) + select({
"//tensorflow:with_xla_support": [
@@ -118,13 +121,15 @@ tf_cuda_library(
":c_api",
":c_api_internal",
"//tensorflow/c/eager:c_api",
- "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
+ "//tensorflow/c/eager:c_api_internal",
+ "//tensorflow/compiler/jit:flags",
"//tensorflow/contrib/tpu:all_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_platform",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/common_runtime/eager:attr_builder",
],
)
@@ -170,6 +175,28 @@ tf_cuda_library(
],
)
+tf_cuda_library(
+ name = "kernels",
+ srcs = [
+ "kernels.cc",
+ ],
+ hdrs = [
+ "kernels.h",
+ ],
+ copts = tf_copts(),
+ visibility = ["//visibility:public"],
+ deps = select({
+ "//tensorflow:android": [
+ ":c_api",
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ ":c_api",
+ "//tensorflow/core:framework",
+ ],
+ }),
+)
+
# -----------------------------------------------------------------------------
# Tests
@@ -197,14 +224,18 @@ tf_cuda_cc_test(
size = "small",
srcs = ["c_api_test.cc"],
data = [
- ":test_op.so",
+ ":test_op1.so",
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
+ kernels = [":test_op_kernel"],
linkopts = select({
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
- tags = ["noasan"],
+ tags = [
+ "no_oss", # http://b/119522529
+ "noasan",
+ ],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
@@ -215,6 +246,7 @@ tf_cuda_cc_test(
"//tensorflow/cc:grad_ops",
"//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/cc/saved_model:tag_constants",
+ "//tensorflow/compiler/jit",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
@@ -232,7 +264,7 @@ tf_cuda_cc_test(
tf_cc_test(
name = "c_api_experimental_test",
- size = "small",
+ size = "medium",
srcs = ["c_api_experimental_test.cc"],
data = ["testdata/tf_record"],
linkopts = select({
@@ -243,8 +275,11 @@ tf_cc_test(
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
+ ":c_api",
":c_api_experimental",
":c_test_util",
+ "//tensorflow/c/eager:c_api",
+ "//tensorflow/c/eager:c_api_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
@@ -281,8 +316,42 @@ tf_cc_test(
)
tf_custom_op_library(
- name = "test_op.so",
+ name = "test_op1.so",
+ srcs = ["test_op1.cc"],
+)
+
+tf_kernel_library(
+ name = "test_op_kernel",
srcs = ["test_op.cc"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
+tf_cuda_cc_test(
+ name = "kernels_test",
+ size = "small",
+ srcs = ["kernels_test.cc"],
+ linkopts = select({
+ "//tensorflow:darwin": ["-headerpad_max_install_names"],
+ "//conditions:default": [],
+ }),
+ tags = ["noasan"],
+ # We must ensure that the dependencies can be dynamically linked since
+ # the shared library must be able to use core:framework.
+ # linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":c_api",
+ ":kernels",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:proto_text",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
)
# -----------------------------------------------------------------------------
diff --git a/tensorflow/c/README.md b/tensorflow/c/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b386998ceaf3e91daba04125fe83e2f3bdd508e5
--- /dev/null
+++ b/tensorflow/c/README.md
@@ -0,0 +1,7 @@
+# TensorFlow C API
+
+- See [www.tensorflow.org/install/lang_c](https://www.tensorflow.org/install/lang_c)
+- Nightly builds:
+ - [Linux CPU-only](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow-cpu-linux-x86_64.tar.gz)
+ - [Linux GPU](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow-gpu-linux-x86_64.tar.gz)
+ - [MacOS CPU-only](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow-cpu-darwin-x86_64.tar.gz)
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 79811ceae57e0bddeb2a6f32bad7003e14e23422..f13e8777dff164bcd8eedf46310ae846abd0c804 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -1942,6 +1942,10 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
const char* prefix) {
opts->opts.prefix = prefix;
}
+void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
+ const char* device) {
+ opts->opts.default_device = device;
+}
void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
unsigned char uniquify_names) {
@@ -2770,6 +2774,9 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
}
string name_str(name, name_len);
const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
+ if (api_def == nullptr) {
+ return nullptr;
+ }
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(*api_def, ret);
@@ -2803,4 +2810,71 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
}
return ret;
}
+
+// TF_Server functions ----------------------------------------------
+
+#ifndef __ANDROID__
+TF_Server::TF_Server(std::unique_ptr server)
+ : target(server->target()), server(std::move(server)) {}
+#endif // __ANDROID__
+
+TF_Server* TF_NewServer(const void* proto, size_t proto_len,
+ TF_Status* status) {
+#ifdef __ANDROID__
+ status->status = tensorflow::errors::Unimplemented(
+ "Server functionality is not supported in Android");
+ return nullptr;
+#else
+ tensorflow::ServerDef server_def;
+ if (!server_def.ParseFromArray(proto, static_cast(proto_len))) {
+ status->status = InvalidArgument(
+ "Could not parse provided bytes into a ServerDef protocol buffer");
+ return nullptr;
+ }
+
+ std::unique_ptr out_server;
+ status->status = tensorflow::NewServer(server_def, &out_server);
+ if (!status->status.ok()) return nullptr;
+
+ return new TF_Server(std::move(out_server));
+#endif
+}
+
+void TF_ServerStart(TF_Server* server, TF_Status* status) {
+#ifdef __ANDROID__
+ status->status = tensorflow::errors::Unimplemented(
+ "Server functionality is not supported in Android");
+#else
+ status->status = server->server->Start();
+#endif
+}
+
+void TF_ServerStop(TF_Server* server, TF_Status* status) {
+#ifdef __ANDROID__
+ status->status = tensorflow::errors::Unimplemented(
+ "Server functionality is not supported in Android");
+#else
+ status->status = server->server->Stop();
+#endif
+}
+
+void TF_ServerJoin(TF_Server* server, TF_Status* status) {
+#ifdef __ANDROID__
+ status->status = tensorflow::errors::Unimplemented(
+ "Server functionality is not supported in Android");
+#else
+ status->status = server->server->Join();
+#endif
+}
+
+const char* TF_ServerTarget(TF_Server* server) {
+#ifdef __ANDROID__
+ return nullptr;
+#else
+ return server->target.c_str();
+#endif
+}
+
+void TF_DeleteServer(TF_Server* server) { delete server; }
+
} // end extern "C"
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 850f6ecd637d768bca99720e0add07680829e17a..3d56268110edbe96616201d15a69cc8c84d3115a 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -900,6 +900,12 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions(
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix(
TF_ImportGraphDefOptions* opts, const char* prefix);
+// Set the execution device for nodes in `graph_def`.
+// Only applies to nodes where a device was not already explicitly specified.
+// `device` is copied and has no lifetime requirements.
+TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetDefaultDevice(
+ TF_ImportGraphDefOptions* opts, const char* device);
+
// Set whether to uniquify imported operation names. If true, imported operation
// names will be modified if their name already exists in the graph. If false,
// conflicting names will be treated as an error. Note that this option has no
@@ -1662,6 +1668,47 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
const char* name, TF_Status* status);
+// --------------------------------------------------------------------------
+// In-process TensorFlow server functionality, for use in distributed training.
+// A Server instance encapsulates a set of devices and a Session target that
+// can participate in distributed training. A server belongs to a cluster
+// (specified by a ClusterSpec), and corresponds to a particular task in a
+// named job. The server can communicate with any other server in the same
+// cluster.
+
+// In-process TensorFlow server.
+typedef struct TF_Server TF_Server;
+
+// Creates a new in-process TensorFlow server configured using a serialized
+// ServerDef protocol buffer provided via `proto` and `proto_len`.
+//
+// The server will not serve any requests until TF_ServerStart is invoked.
+// The server will stop serving requests once TF_ServerStop or
+// TF_DeleteServer is invoked.
+TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto,
+ size_t proto_len,
+ TF_Status* status);
+
+// Starts an in-process TensorFlow server.
+TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status);
+
+// Stops an in-process TensorFlow server.
+TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status);
+
+// Blocks until the server has been successfully stopped (via TF_ServerStop or
+// TF_ServerClose).
+TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status);
+
+// Returns the target string that can be provided to TF_SetTarget() to connect
+// a TF_Session to `server`.
+//
+// The returned string is valid only until TF_DeleteServer is invoked.
+TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server);
+
+// Destroy an in-process TensorFlow server, frees memory. If server is running
+// it will be stopped and joined.
+TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index d4b78138e93624a7e41e917f8210281b500661bc..3693cc85996365360253c8a94c29272a16e11e9a 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -15,12 +15,18 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
+#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/compiler/jit/flags.h"
+#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@@ -50,8 +56,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
// These XLA flags are needed to trigger XLA properly from C (more generally
// non-Python) clients. If this API is called again with `enable` set to
// false, it is safe to keep these flag values as is.
- tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
- tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+ tensorflow::MarkForCompilationPassFlags* flags =
+ tensorflow::GetMarkForCompilationPassFlags();
flags->tf_xla_cpu_global_jit = true;
flags->tf_xla_min_cluster_size = 1;
} else {
@@ -70,8 +76,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
// These XLA flags are needed to trigger XLA properly from C (more generally
// non-Python) clients. If this API is called again with `enable` set to
// false, it is safe to keep these flag values as is.
- tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
- tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+ tensorflow::MarkForCompilationPassFlags* flags =
+ tensorflow::GetMarkForCompilationPassFlags();
flags->tf_xla_cpu_global_jit = true;
flags->tf_xla_min_cluster_size = 1;
} else {
@@ -8738,7 +8744,145 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
TF_DeleteStatus(status);
}
-TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
- const char* errMsg) {
+struct TFE_ExecuteOpNotification {
+ TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
+ tensorflow::Notification n;
+ std::unique_ptr thread;
+ std::unique_ptr status;
+};
+
+TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
+ TFE_TensorHandle** retvals,
+ int* num_retvals,
+ TF_Status* status) {
+ TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
+
+ n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
+ tensorflow::ThreadOptions(), "ExecuteOpThread",
+ [op, retvals, num_retvals, n]() {
+ TFE_Execute(op, retvals, num_retvals, n->status.get());
+ n->n.Notify();
+ }));
+
+ return n;
+}
+
+void TFE_ExecuteOpNotificationWaitAndDelete(
+ TFE_ExecuteOpNotification* notification, TF_Status* status) {
+ if (notification == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Passed in notification is a nullptr.");
+
+ return;
+ }
+ if (notification->thread == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Passed in notification didn't start a thread correctly. Cleaning up "
+ "this notification. Please re-execute the operation to get a new "
+ "notification.");
+
+ delete notification;
+ return;
+ }
+
+ notification->n.WaitForNotification();
+
+ status->status = notification->status->status;
+
+ delete notification;
+}
+
+void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}
+
+// This builder is used in the eager API to build a NodeDef.
+struct TF_AttrBuilder : public tensorflow::AttrBuilder {
+ using tensorflow::AttrBuilder::AttrBuilder;
+ // The string buffers to make sure that any `attr_name` we pass into
+ // `builder->Set()` will outlive the subsequent
+ // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
+ std::set attr_names;
+};
+
+TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
+ return new TF_AttrBuilder(op_name);
+}
+
+void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }
+
+void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
+ TF_DataType value) {
+ auto iter = builder->attr_names.insert(attr_name).first;
+ builder->Set((*iter).c_str(), static_cast(value));
+}
+
+void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
+ const TF_DataType* values, int num_values) {
+ auto iter = builder->attr_names.insert(attr_name).first;
+ builder->Set(
+ (*iter).c_str(),
+ tensorflow::gtl::ArraySlice(
+ reinterpret_cast(values), num_values));
+}
+
+void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
+ const char* device_type,
+ TF_Status* status) {
+ status->status = tensorflow::FindKernelDef(
+ tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
+ /* def = */ nullptr, /* kernel_class_name = */ nullptr);
+}
+
+const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
+ TF_Status* status) {
+ const tensorflow::OpDef* op_def = nullptr;
+ status->status =
+ tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
+ if (!status->status.ok()) return nullptr;
+
+ if (input_index >= op_def->input_arg_size() || input_index < 0) {
+ status->status = tensorflow::errors::InvalidArgument(
+ input_index, " out of range for ", op_name);
+ return nullptr;
+ }
+
+ const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];
+
+ if (input_arg.number_attr().empty()) {
+ status->status = tensorflow::errors::NotFound(
+ op_name, " does not have number_attr() defined.");
+ return nullptr;
+ }
+
+ // The returned string is owned by OpRegistry, so liveness is not a concern.
+ return input_arg.number_attr().c_str();
+}
+
+int TF_OpIsStateful(const char* op_type, TF_Status* status) {
+ const tensorflow::OpRegistrationData* op_reg_data;
+ status->status =
+ tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
+ if (!status->status.ok()) {
+ return 0;
+ }
+ return op_reg_data->op_def.is_stateful();
+}
+
+void TF_InitMain(const char* usage, int* argc, char*** argv) {
+ tensorflow::port::InitMain(usage, argc, argv);
+}
+
+int TF_PickUnusedPortOrDie() {
+ return tensorflow::internal::PickUnusedPortOrDie();
+}
+
+TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType dtype_arg,
+ void* data, size_t len) {
+ auto dtype = static_cast(dtype_arg);
+ DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));
+
+ tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
+ std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
+ return new TFE_TensorHandle(tensor, nullptr, nullptr);
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index d98d532e32e891e21f5b7ba360c74c3256fb1947..80c8bfe594c4c89606efd01bec7f50e7a86b5bda 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -180,9 +180,72 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
TFE_TensorHandle* handle);
+typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification;
+
+// Allows invoking a kernel asynchronously, and explicitly returns a
+// notification that can be waited upon. This always executes the kernel in a
+// new thread.
+// 1. `retvals` and `num_retvals` can only be consumed after
+// `TFE_ExecuteOp` returns successfully. They shouldn't be used
+// if the return is unsuccessful
+// 2. These new APIs cannot be used together with the TFE context level async
+// support.
+TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(
+ TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
+ TF_Status* status);
+
+// Waits to complete the op execution, and cleans up the notification.
+// Errors reported by op execution are set in `status`.
+TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
+ TFE_ExecuteOpNotification* notification, TF_Status* status);
+
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);
+// TF_NewAttrBuilder() returns an object that you can set attributes on as
+// though it were an op. This allows querying properties of that op for
+// type-checking purposes like if the op will run on a particular device type.
+typedef struct TF_AttrBuilder TF_AttrBuilder;
+TF_CAPI_EXPORT extern TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name);
+TF_CAPI_EXPORT extern void TF_DeleteAttrBuilder(TF_AttrBuilder* builder);
+TF_CAPI_EXPORT extern void TF_AttrBuilderSetType(TF_AttrBuilder* builder,
+ const char* attr_name,
+ TF_DataType value);
+TF_CAPI_EXPORT extern void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder,
+ const char* attr_name,
+ const TF_DataType* values,
+ int num_values);
+
+// Checks the tensorflow::NodeDef built via the methods above to see if it can
+// run on device_type.
+TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice(
+ TF_AttrBuilder* builder, const char* device_type, TF_Status* status);
+
+// For argument number input_index, fetch the corresponding number_attr that
+// needs to be updated with the argument length of the input list.
+// Returns nullptr if there is any problem like op_name is not found, or the
+// argument does not support this attribute type.
+TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput(
+ const char* op_name, int input_index, TF_Status* status);
+
+// Returns 1 if the op is stateful, 0 otherwise. The return value is undefined
+// if the status is not ok.
+TF_CAPI_EXPORT extern int TF_OpIsStateful(const char* op_type,
+ TF_Status* status);
+
+// Platform specific initialization routine. Very few platforms actually require
+// this to be called.
+TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv);
+
+// Platform-specific implementation to return an unused port. (This should used
+// in tests only.)
+TF_CAPI_EXPORT int TF_PickUnusedPortOrDie();
+
+// Fast path method that makes constructing a single scalar tensor require less
+// overhead and copies.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar(
+ TF_DataType dtype, void* scalar, size_t len);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc
index c6effd39697e0397278770b53e98508074f99862..daa7701b7fe7e8ce757b6504329cf6434ad39778 100644
--- a/tensorflow/c/c_api_experimental_test.cc
+++ b/tensorflow/c/c_api_experimental_test.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/c_test_util.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -162,5 +164,137 @@ protocol: "grpc"
TF_DeleteStatus(status);
}
+TEST(CAPI_EXPERIMENTAL, IsStateful) {
+ std::unique_ptr status(
+ TF_NewStatus(), TF_DeleteStatus);
+ int assign = TF_OpIsStateful("AssignAddVariableOp", status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ EXPECT_EQ(assign, 1);
+ int id = TF_OpIsStateful("Identity", status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ EXPECT_EQ(id, 0);
+}
+
+TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+
+ TFE_Op* matmul_op = MatMulOp(ctx, m, m);
+
+ TFE_TensorHandle* retvals[1] = {nullptr};
+ int num_retvals = 1;
+
+ auto* r =
+ TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status);
+
+ TFE_ExecuteOpNotificationWaitAndDelete(r, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+
+ TFE_DeleteOp(matmul_op);
+ TFE_DeleteTensorHandle(m);
+
+ TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+}
+
+// Perform a send/recv test. Recv blocks, so they need to be executed
+// asynchronously.
+TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ // Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4.
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+
+ // Build a send op.
+ TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(send_op, m, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ string tensor_name = "Tensor";
+ TFE_OpSetAttrType(send_op, "T", TF_FLOAT);
+ TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(),
+ tensor_name.size());
+ string send_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+ TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(),
+ send_device.size());
+ TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234);
+ string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+ TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(),
+ recv_device.size());
+ TFE_OpSetAttrBool(send_op, "client_terminated", true);
+
+ // Build a recv op.
+ TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT);
+ TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(),
+ tensor_name.size());
+ TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(),
+ send_device.size());
+ TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234);
+ TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(),
+ recv_device.size());
+ TFE_OpSetAttrBool(recv_op, "client_terminated", true);
+
+ TFE_TensorHandle* send_retvals;
+ int send_num_retvals = 0;
+ auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals,
+ &send_num_retvals, status);
+
+ TFE_TensorHandle* recv_retvals[1] = {nullptr};
+ int recv_num_retvals = 1;
+ auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0],
+ &recv_num_retvals, status);
+
+ TFE_ExecuteOpNotificationWaitAndDelete(send_result, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(1, product[0]);
+ EXPECT_EQ(2, product[1]);
+ EXPECT_EQ(3, product[2]);
+ EXPECT_EQ(4, product[3]);
+
+ TFE_DeleteOp(send_op);
+ TFE_DeleteOp(recv_op);
+ TFE_DeleteTensorHandle(m);
+
+ TFE_DeleteTensorHandle(recv_retvals[0]);
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index f68f8a3e90a971b5e4a024feaf26ba498afc48da..28b9f8df9c873ee394eb6a241dd9ac06ba6c8796 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -392,26 +392,26 @@ Status ProcessInputs(
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
- const Node& node = inputs[i].oper->node;
+ Node* node = &inputs[i].oper->node;
int idx = inputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
- fn_body->graph.IsValidOutputTensor(&node, idx),
+ fn_body->graph.IsValidOutputTensor(node, idx),
"Encountered while processing input ", i, " into function '", fn_name,
"'");
- TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
"Encountered while processing input ", i,
" into function '", fn_name, "'");
- input_tensors->emplace_back(&node, idx);
+ input_tensors->emplace_back(node, idx);
- const auto& iter = input_nodes->find(&node);
+ const auto& iter = input_nodes->find(node);
if (iter == input_nodes->end()) {
- input_nodes->insert({&node, {idx}});
+ input_nodes->insert({node, {idx}});
} else {
auto& indices = iter->second;
if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
- return InvalidArgument("TF_Output ", node.name(), ":", idx,
+ return InvalidArgument("TF_Output ", node->name(), ":", idx,
" appears more than once in the input list");
}
indices.push_back(idx);
@@ -428,16 +428,16 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
- const Node& node = outputs[i].oper->node;
+ Node* node = &outputs[i].oper->node;
int idx = outputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
- fn_body->graph.IsValidOutputTensor(&node, idx),
+ fn_body->graph.IsValidOutputTensor(node, idx),
"Encountered while processing output ", i, " from function '", fn_name,
"'");
- TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
"Encountered while creating function '",
fn_name, "'");
- output_tensors->emplace_back(&node, idx);
+ output_tensors->emplace_back(node, idx);
}
return Status::OK();
}
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index 95652a11378d6276b5ba6540a07baa15aa77cc1c..5ba26d3c585350aa510f9970cbfc246a9a108543 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -25,6 +25,7 @@ limitations under the License.
#include
#ifndef __ANDROID__
+#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#endif
#include "tensorflow/core/common_runtime/shape_refiner.h"
@@ -179,6 +180,15 @@ struct TF_ApiDefMap {
tensorflow::mutex lock;
};
+#ifndef __ANDROID__
+struct TF_Server {
+ TF_Server(std::unique_ptr server);
+
+ const tensorflow::string target;
+ std::unique_ptr server;
+};
+#endif
+
namespace tensorflow {
class TensorCApi {
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 03516c39dc970aa23967107d3a0446da94669465..d5934a10395ae094f65d3bc8b6cd7b94dbd32410 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb_text.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
@@ -186,23 +187,40 @@ TEST(CAPI, LibraryLoadFunctions) {
// tf_cuda_cc_test() bazel rule and remove the next line.
if (!GPUDeviceName().empty()) return;
- // Load the library.
- TF_Status* status = TF_NewStatus();
- TF_Library* lib =
- TF_LoadLibrary("tensorflow/c/test_op.so", status);
- TF_Code code = TF_GetCode(status);
- string status_msg(TF_Message(status));
- TF_DeleteStatus(status);
- ASSERT_EQ(TF_OK, code) << status_msg;
-
- // Test op list.
- TF_Buffer op_list_buf = TF_GetOpList(lib);
- tensorflow::OpList op_list;
- EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
- ASSERT_EQ(op_list.op_size(), 1);
- EXPECT_EQ("TestCApi", op_list.op(0).name());
-
- TF_DeleteLibraryHandle(lib);
+#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
+ {
+ // Load the library.
+ TF_Status* status = TF_NewStatus();
+ TF_Library* lib =
+ TF_LoadLibrary("tensorflow/c/test_op1.so", status);
+ TF_Code code = TF_GetCode(status);
+ string status_msg(TF_Message(status));
+ TF_DeleteStatus(status);
+ ASSERT_EQ(TF_OK, code) << status_msg;
+
+ // Test op list.
+ TF_Buffer op_list_buf = TF_GetOpList(lib);
+ tensorflow::OpList op_list;
+ EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
+ ASSERT_EQ(op_list.op_size(), 1);
+ EXPECT_EQ("TestCApi1", op_list.op(0).name());
+ TF_DeleteLibraryHandle(lib);
+ }
+#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
+ {
+ TF_Buffer* op_list_buffer = TF_GetAllOpList();
+ tensorflow::OpList op_list;
+ op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
+ ASSERT_GE(op_list.op_size(), 1);
+ typedef tensorflow::protobuf::RepeatedPtrField OpDefs;
+ const OpDefs& ops = op_list.op();
+ bool found = std::find_if(ops.begin(), ops.end(),
+ [](const tensorflow::OpDef& op_def) {
+ return op_def.name() == "TestCApi";
+ }) != ops.end();
+ EXPECT_TRUE(found);
+ TF_DeleteBuffer(op_list_buffer);
+ }
}
void TestEncodeDecode(int line, const std::vector& data) {
@@ -2329,15 +2347,9 @@ TEST(TestApiDef, TestCreateApiDef) {
// tf_cuda_cc_test() bazel rule and remove the next line.
if (!GPUDeviceName().empty()) return;
+ TF_Buffer* op_list_buf = TF_GetAllOpList();
TF_Status* status = TF_NewStatus();
- TF_Library* lib =
- TF_LoadLibrary("tensorflow/c/test_op.so", status);
- EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteStatus(status);
-
- TF_Buffer op_list_buf = TF_GetOpList(lib);
- status = TF_NewStatus();
- auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
+ auto* api_def_map = TF_NewApiDefMap(op_list_buf, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
@@ -2355,7 +2367,7 @@ TEST(TestApiDef, TestCreateApiDef) {
TF_DeleteBuffer(api_def_buf);
TF_DeleteApiDefMap(api_def_map);
- TF_DeleteLibraryHandle(lib);
+ TF_DeleteBuffer(op_list_buf);
}
TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
@@ -2363,15 +2375,9 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
// tf_cuda_cc_test() bazel rule and remove the next line.
if (!GPUDeviceName().empty()) return;
+ TF_Buffer* op_list_buf = TF_GetAllOpList();
TF_Status* status = TF_NewStatus();
- TF_Library* lib =
- TF_LoadLibrary("tensorflow/c/test_op.so", status);
- EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteStatus(status);
-
- TF_Buffer op_list_buf = TF_GetOpList(lib);
- status = TF_NewStatus();
- auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
+ auto* api_def_map = TF_NewApiDefMap(op_list_buf, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
@@ -2400,7 +2406,7 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
TF_DeleteBuffer(api_def_buf);
TF_DeleteApiDefMap(api_def_map);
- TF_DeleteLibraryHandle(lib);
+ TF_DeleteBuffer(op_list_buf);
}
class DummyKernel : public tensorflow::OpKernel {
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 3ee31a6a7ac641bbd3fc4c05568b61e433a1d523..c34a84fcfee9b6ba9a7be86ae16e2856a2d343c7 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -50,6 +50,7 @@ tf_cuda_library(
],
"//conditions:default": [],
}) + [
+ "@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
@@ -69,7 +70,7 @@ tf_cuda_library(
name = "c_api_internal",
hdrs = ["c_api_internal.h"],
visibility = [
- "//learning/deepmind/courier:__pkg__",
+ "//learning/deepmind/courier:__subpackages__",
"//tensorflow:internal",
],
deps = [
@@ -143,6 +144,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 3554ec0bf3202b54bfc38d67e51b89df19832302..027d752f420238da867cb9d8c116640e1730caaa 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -21,9 +21,11 @@ limitations under the License.
#include
#include
+#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/core/platform/host_info.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif // TENSORFLOW_EAGER_USE_XLA
@@ -79,7 +81,7 @@ tensorflow::Status GetAllRemoteDevices(
const std::vector& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr* device_mgr) {
- std::vector remote_devices;
+ std::vector> remote_devices;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
for (const string& remote_worker : remote_workers) {
@@ -92,7 +94,7 @@ tensorflow::Status GetAllRemoteDevices(
status = s;
if (s.ok()) {
for (tensorflow::Device* d : *devices) {
- remote_devices.push_back(d);
+ remote_devices.emplace_back(d);
}
}
n.Notify();
@@ -100,7 +102,7 @@ tensorflow::Status GetAllRemoteDevices(
n.WaitForNotification();
}
std::unique_ptr remote_device_mgr(
- new tensorflow::DeviceMgr(remote_devices));
+ new tensorflow::DeviceMgr(std::move(remote_devices)));
TF_RETURN_IF_ERROR(status);
@@ -261,13 +263,13 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
- std::vector devices;
+ std::vector> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
&devices);
if (!status->status.ok()) return nullptr;
std::unique_ptr device_mgr(
- new tensorflow::DeviceMgr(devices));
+ new tensorflow::DeviceMgr(std::move(devices)));
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
@@ -404,8 +406,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
"The passed in handle is a nullptr");
return nullptr;
}
- tensorflow::Device* d = nullptr;
- status->status = h->handle->OpDevice(&d);
+ tensorflow::Device* d = h->handle->op_device();
+ return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
+ : d->name().c_str();
+}
+
+const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
+ TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return nullptr;
+ }
+ tensorflow::Device* d = h->handle->device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
@@ -459,13 +472,20 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
- status->status = tensorflow::AttrTypeMapForOp(name, &types);
- if (status->status.ok()) return new TFE_Op(ctx, name, types);
- if (TF_GetCode(status) == TF_NOT_FOUND) {
- if (ctx->context.FindFunctionByName(name)) {
- status->status = tensorflow::Status::OK();
- return new TFE_Op(ctx, name, nullptr);
+ bool is_function = false;
+ status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
+ if (status->status.ok()) {
+ if (is_function && !ctx->context.FindFunctionByName(name)) {
+ status->status = tensorflow::errors::NotFound(
+ "'", name,
+ "' is neither a type of a primitive operation nor a name "
+ "of a function registered in binary running on ",
+ tensorflow::port::Hostname(),
+ ". Make sure the operation or function is "
+ "registered in the binary running in this process.");
+ return nullptr;
}
+ return new TFE_Op(ctx, name, is_function, types);
}
return nullptr;
}
@@ -498,12 +518,6 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret;
- if (op->operation.is_function()) {
- status->status = tensorflow::errors::Unimplemented(
- "TODO(apassos): Support for attributes for TensorFlow functions is not "
- "ready yet.");
- return TF_ATTR_INT; // The compiler requires that we return something.
- }
status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
attr_name, &ret, is_list);
return ret;
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index b2454d872207e26feb3764671474a5d87c01f84d..8d6c8d958d5961fce817156a14eb2b2940c1f2f0 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -169,10 +169,33 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h,
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
int dim_index,
TF_Status* status);
+
+// Returns the device of the operation that produced `h`.
+// If `h` was produced by a copy, returns the destination device of
+// the copy. Note that returned device name is not always the device
+// holding the tensor handle's memory. If you want the latter, use
+// TFE_TensorHandleBackingDeviceName.
+// This function will block till the operation that produces `h` has completed.
+//
+// Device on which the kernel of the operation that produced `h` ran.
+//
+// If `h` was produced by a copy, returns the destination device of
+// the copy.
+//
+// Note that returned device name is not always the device that owns the memory
+// that backs the tensor handle. For the latter see
+// TFE_TensorHandleBackingDeviceName.
+//
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
TFE_TensorHandle* h, TF_Status* status);
+// Returns the name of the device in whose memory `h` resides.
+//
+// This function will block till the operation that produces `h` has completed.
+TF_CAPI_EXPORT extern const char* TFE_TensorHandleBackingDeviceName(
+ TFE_TensorHandle* h, TF_Status* status);
+
// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor
// with `h`. On success, `status` is set to OK. On failure, `status` reflects
// the error and a nullptr is returned.
diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc
index 5006b76f1981d068e99a2c081115ebb3a66d8c7f..52b0824552855860dfb138f3ac9a5d3afa7dc965 100644
--- a/tensorflow/c/eager/c_api_debug.cc
+++ b/tensorflow/c/eager/c_api_debug.cc
@@ -57,13 +57,9 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
return nullptr;
}
- tensorflow::Device* device;
- status->status = handle->handle->Device(&device);
- if (!status->status.ok()) {
- return nullptr;
- }
-
#ifdef TENSORFLOW_EAGER_USE_XLA
+ tensorflow::Device* device = handle->handle->device();
+
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =
dynamic_cast(device);
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 104d52430cf7aa14d4d2a335a1b96e667f21ce87..67bc1bcd24605f8363d6a7c8d5d6a0836a42fc82 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -79,10 +79,6 @@ struct TFE_TensorHandle {
tensorflow::Device* op_device)
: handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {}
- TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype,
- tensorflow::EagerContext* ctx)
- : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {}
-
TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {}
tensorflow::TensorHandle* handle;
@@ -97,10 +93,9 @@ struct TFE_TensorDebugInfo {
};
struct TFE_Op {
- // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
- // primitive operation.
- TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
- : operation(&ctx->context, op, t) {}
+ TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
+ const tensorflow::AttrTypeMap* t)
+ : operation(&ctx->context, op, is_function, t) {}
tensorflow::EagerOperation operation;
};
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 55331022b9dbd0696928fa44430f340f371432ac..6b39b79ee82f9c7baaf856e573a42b7da65691e5 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include
+#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -589,9 +590,22 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
+ bool has_gpu0 = false;
+ bool has_gpu1 = false;
+ for (int i = 0; i < num_devices; ++i) {
+ const char* dev = TF_DeviceListName(devices, i, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ string device_name(dev);
+ if (device_name.find("GPU:0") != string::npos) {
+ has_gpu0 = true;
+ }
+ if (device_name.find("GPU:1") != string::npos) {
+ has_gpu1 = true;
+ }
+ }
const char* kCPUDevice = "CPU:0";
- if (num_devices < 3) {
+ if (!has_gpu0 || !has_gpu1) {
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
@@ -781,6 +795,14 @@ TEST(CAPI, TensorHandleNullptr) {
TF_SetStatus(status.get(), TF_OK, "");
+ device_name = TFE_TensorHandleBackingDeviceName(h, status.get());
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+ ASSERT_EQ(device_name, nullptr);
+ ASSERT_EQ("The passed in handle is a nullptr",
+ string(TF_Message(status.get())));
+
+ TF_SetStatus(status.get(), TF_OK, "");
+
int num_dims = TFE_TensorHandleNumDims(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(num_dims, -1);
@@ -796,6 +818,62 @@ TEST(CAPI, TensorHandleNullptr) {
string(TF_Message(status.get())));
}
+TEST(CAPI, TensorHandleDevices) {
+ std::unique_ptr status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status.get());
+ TFE_DeleteContextOptions(opts);
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+ const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name;
+ const char* backing_device_name =
+ TFE_TensorHandleBackingDeviceName(hcpu, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
+ << backing_device_name;
+
+ // Disable the test if no GPU is present.
+ string gpu_device_name;
+ if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
+ TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
+ hcpu, ctx, gpu_device_name.c_str(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TFE_Op* shape_op = ShapeOp(ctx, hgpu);
+ TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // .device of shape is GPU since the op is executed on GPU
+ device_name = TFE_TensorHandleDeviceName(retvals[0], status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(device_name, "GPU:0")) << device_name;
+
+ // .backing_device of shape is CPU since the tensor is backed by CPU
+ backing_device_name =
+ TFE_TensorHandleBackingDeviceName(retvals[0], status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
+ << backing_device_name;
+
+ TFE_DeleteOp(shape_op);
+ TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteTensorHandle(hgpu);
+ }
+
+ TFE_DeleteTensorHandle(hcpu);
+ TFE_ContextAsyncWait(ctx, status.get());
+ EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_DeleteContext(ctx);
+}
+
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc
index 008f088c2dcdd7d9114103516a4702e47a55c6de..bd38127d50c171af801dd1b937acefdba491b4a6 100644
--- a/tensorflow/c/eager/c_api_test_util.cc
+++ b/tensorflow/c/eager/c_api_test_util.cc
@@ -104,6 +104,19 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
return op;
}
+TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
+ TF_Status* status = TF_NewStatus();
+
+ TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, a, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+ TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
+
+ return op;
+}
+
TFE_TensorHandle* TestAxisTensorHandle() {
int64_t dims[] = {1};
int data[] = {1};
diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h
index 474cae67c89249af3a62707f0db00ba458ca8f31..75ef9459e93b4f2ed471c423a34565594efc1714 100644
--- a/tensorflow/c/eager/c_api_test_util.h
+++ b/tensorflow/c/eager/c_api_test_util.h
@@ -37,6 +37,9 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2();
// Return a matmul op multiplying `a` by `b`.
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
+// Return a shape op fetching the shape of `a`.
+TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);
+
// Return an 1-D INT32 tensor containing a single value 1.
TFE_TensorHandle* TestAxisTensorHandle();
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 5ba55a203ff70cc64c07e96b5a869a1f11c9334e..5c11f51e8749de84547ae873f5f55ebd42bc4b3d 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -141,8 +141,9 @@ class GradientTape {
// null. The result is populated with one tensor per target element.
Status ComputeGradient(
const VSpace& vspace,
- gtl::ArraySlice target_tensor_ids,
- gtl::ArraySlice source_tensor_id,
+ const gtl::ArraySlice target_tensor_ids,
+ const gtl::ArraySlice source_tensor_ids,
+ const gtl::FlatMap sources_that_are_targets,
gtl::ArraySlice output_gradients,
std::vector* result);
@@ -396,6 +397,7 @@ template
Status InitialGradients(
const VSpace& vspace,
gtl::ArraySlice target_tensor_ids,
+ gtl::FlatMap sources_that_are_targets,
gtl::ArraySlice output_gradients, const TensorTape& tensor_tape,
const OpTape& op_tape,
gtl::FlatMap>* result) {
@@ -425,8 +427,13 @@ Status InitialGradients(
"none of operations outputs match expected tensor");
}
} else {
- // No record of the target tensor found on the tape, so no gradient
- // needs to be computed from it. Do nothing.
+ // This target tensor was not generated by any operation recorded on
+ // the tape, so no gradient needs to be computed from it unless this
+ // target is also a source.
+ auto source_tensor = sources_that_are_targets.find(id);
+ if (source_tensor != sources_that_are_targets.end()) {
+ (*result)[id].push_back(vspace.Ones(source_tensor->second));
+ }
}
} else {
(*result)[id].push_back(output_gradients[i]);
@@ -467,8 +474,9 @@ constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
template
Status GradientTape::ComputeGradient(
const VSpace& vspace,
- gtl::ArraySlice target_tensor_ids,
- gtl::ArraySlice source_tensor_ids,
+ const gtl::ArraySlice target_tensor_ids,
+ const gtl::ArraySlice source_tensor_ids,
+ const gtl::FlatMap sources_that_are_targets,
gtl::ArraySlice output_gradients,
std::vector* result) {
gtl::FlatSet sources_set(source_tensor_ids.begin(),
@@ -478,7 +486,8 @@ Status GradientTape::ComputeGradient(
std::vector op_stack =
InitialStack(state.op_tape, state.op_missing_tensor);
gtl::FlatMap> gradients;
- Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
+ Status s = InitialGradients(vspace, target_tensor_ids,
+ sources_that_are_targets, output_gradients,
tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() {
if (!persistent_) {
diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ca69345264607ac689fb556b4f5c9bc08ea5eb88
--- /dev/null
+++ b/tensorflow/c/kernels.cc
@@ -0,0 +1,118 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include
+
+#include "tensorflow/c/kernels.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+// This file forms the basis of a stable ABI for third-party kernel
+// implementations. It is crucial that changes to this file are made cautiously
+// and with a focus on maintaining both source and binary compatibility.
+
+struct TF_KernelBuilder {
+ ::tensorflow::KernelDefBuilder* cc_builder;
+
+ void* (*create_function)(TF_OpKernelConstruction*);
+ void (*compute_function)(void*, TF_OpKernelContext*);
+ void (*delete_function)(void*);
+};
+
+TF_KernelBuilder* TF_NewKernelBuilder(
+ const char* op_name, const char* device_name,
+ void* (*create_func)(TF_OpKernelConstruction*),
+ void (*compute_func)(void*, TF_OpKernelContext*),
+ void (*delete_func)(void*)) {
+ TF_KernelBuilder* result = new TF_KernelBuilder;
+ result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
+ result->cc_builder->Device(device_name);
+ result->create_function = create_func;
+ result->compute_function = compute_func;
+ result->delete_function = delete_func;
+ return result;
+}
+
+void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
+ DCHECK_NE(builder, nullptr);
+ delete builder->cc_builder;
+ delete builder;
+}
+
+namespace tensorflow {
+namespace {
+
+// An OpKernel whose methods delegate to C function pointers.
+class COpKernel : public OpKernel {
+ public:
+ explicit COpKernel(OpKernelConstruction* ctx,
+ void* (*create_func)(TF_OpKernelConstruction*),
+ void (*compute_func)(void*, TF_OpKernelContext*),
+ void (*delete_func)(void*))
+ : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
+ if (create_func != nullptr) {
+ c_kernel_ =
+ (*create_func)(reinterpret_cast(ctx));
+ } else {
+ c_kernel_ = nullptr;
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ (*compute_func_)(c_kernel_, reinterpret_cast(ctx));
+ }
+
+ ~COpKernel() override {
+ if (delete_func_ != nullptr) {
+ (*delete_func_)(c_kernel_);
+ }
+ }
+
+ private:
+ void (*compute_func_)(void*, TF_OpKernelContext* context);
+ void (*delete_func_)(void*);
+ void* c_kernel_;
+};
+
+// A KernelFactory that returns COpKernel instances.
+class KernelBuilderFactory
+ : public ::tensorflow::kernel_factory::OpKernelFactory {
+ public:
+ explicit KernelBuilderFactory(TF_KernelBuilder* builder)
+ : builder_(builder) {}
+ ::tensorflow::OpKernel* Create(
+ ::tensorflow::OpKernelConstruction* context) override {
+ return new ::tensorflow::COpKernel(context, builder_->create_function,
+ builder_->compute_function,
+ builder_->delete_function);
+ }
+ ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }
+
+ private:
+ TF_KernelBuilder* builder_;
+};
+} // namespace
+} // namespace tensorflow
+
+void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
+ TF_Status* status) {
+ using tensorflow::register_kernel::Name;
+
+ tensorflow::kernel_factory::OpKernelRegistrar(
+ builder->cc_builder->Build(), name,
+ absl::make_unique(builder));
+
+ TF_SetStatus(status, TF_OK, "");
+}
diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h
new file mode 100644
index 0000000000000000000000000000000000000000..2518789a3c141755d0b3373d53642c487331f68b
--- /dev/null
+++ b/tensorflow/c/kernels.h
@@ -0,0 +1,92 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_KERNELS_H_
+#define TENSORFLOW_C_KERNELS_H_
+
+#include "tensorflow/c/c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// --------------------------------------------------------------------------
+// C API for TensorFlow Kernels.
+//
+// This API allows developers to register custom kernel implementations for
+// TensorFlow.
+//
+// See c_api.h header comments for a discussion about API conventions.
+//
+// Users wishing to extend TensorFlow with new kernels will call
+// `TF_NewKernelBuilder`. The resulting kernel builder can be registered with
+// `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided
+// kernels when necessary.
+
+struct TF_KernelBuilder;
+struct TF_OpKernelConstruction;
+struct TF_OpKernelContext;
+
+// Allocates a new kernel builder and returns a pointer to it.
+//
+// If non-null, TensorFlow will call create_func when it needs to instantiate
+// the kernel. The pointer returned by create_func will be passed to
+// compute_func and delete_func, thereby functioning as a "this" pointer for
+// referring to kernel instances.
+//
+// The TF_OpKernelConstruction pointer passed to create_func is owned by
+// TensorFlow and will be deleted once create_func returns. It must not be used
+// after this.
+//
+// When TensorFlow needs to perform a computation with this kernel, it will
+// call compute_func. This function will receive the pointer returned by
+// create_func (or null if no create_func was provided), along with the inputs
+// to the computation.
+//
+// The TF_OpKernelContext pointer received by compute_func is owned by
+// TensorFlow and will be deleted once compute_func returns. It must not be used
+// after this.
+//
+// Finally, when TensorFlow no longer needs the kernel, it will call
+// delete_func if one is provided. This function will receive the pointer
+// returned in `create_func` or nullptr if no `create_func` was provided.
+//
+// The caller should pass the result of this function to
+// TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for
+// some reason, the kernel builder will not be registered, the caller should
+// delete it with TF_DeleteKernelBuilder.
+TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder(
+ const char* op_name, const char* device_name,
+ void* (*create_func)(TF_OpKernelConstruction*),
+ void (*compute_func)(void*, TF_OpKernelContext*),
+ void (*delete_func)(void*));
+
+// Register the given kernel builder with the TensorFlow runtime. If
+// registration fails, the given status will be populated.
+//
+// This call takes ownership of the `builder` pointer.
+TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name,
+ TF_KernelBuilder* builder,
+ TF_Status* status);
+
+// Deletes the given TF_KernelBuilder. This should be called only if the kernel
+// builder is not registered with TensorFlow via TF_RegisterKernelBuilder.
+TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
+
+#ifdef __cplusplus
+} /* end extern "C" */
+#endif
+
+#endif // TENSORFLOW_C_KERNELS_H_
diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e706c7c1d96ee1781d8efc0f28c5e0cbcbc80861
--- /dev/null
+++ b/tensorflow/c/kernels_test.cc
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/kernels.h"
+
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/node_def.pb_text.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+struct MyCustomKernel {
+ bool created;
+ bool compute_called;
+};
+
+static bool delete_called = false;
+
+static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
+ LOG(INFO) << "Wow, actually got into creation";
+ struct MyCustomKernel* s = new struct MyCustomKernel;
+ s->created = true;
+ s->compute_called = false;
+ return s;
+}
+
+static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) {
+ struct MyCustomKernel* s = static_cast(kernel);
+ s->compute_called = true;
+}
+
+static void MyDeleteFunc(void* kernel) {
+ struct MyCustomKernel* s = static_cast(kernel);
+ EXPECT_TRUE(s->created);
+ EXPECT_TRUE(s->compute_called);
+ delete_called = true;
+ delete s;
+}
+
+// Tests registration of a single C kernel and checks that calls through the
+// C/C++ boundary are being made.
+TEST(TestKernel, TestRegisterKernelBuilder) {
+ const char* kernel_name = "SomeKernelName";
+ const char* op_name = "FooOp";
+ const char* device_name = "barDev";
+
+ TF_KernelBuilder* builder = TF_NewKernelBuilder(
+ op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);
+
+ {
+ TF_Status* status = TF_NewStatus();
+ TF_RegisterKernelBuilder(kernel_name, builder, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status));
+ TF_Buffer* buf = TF_GetRegisteredKernelsForOp("FooOp", status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status));
+ ::tensorflow::KernelList list;
+ list.ParseFromArray(buf->data, buf->length);
+ ASSERT_EQ(1, list.kernel_size());
+ ASSERT_EQ("barDev", list.kernel(0).device_type());
+ TF_DeleteBuffer(buf);
+ TF_DeleteStatus(status);
+ }
+
+ REGISTER_OP("FooOp")
+ .Input("input1: double")
+ .Input("input2: uint8")
+ .Output("output1: uint8");
+
+ {
+ ::tensorflow::NodeDef def;
+ def.set_op("FooOp");
+ def.set_device("bar");
+ def.add_input("input1");
+ def.add_input("input2");
+ ::tensorflow::Status status;
+ std::unique_ptr<::tensorflow::OpKernel> kernel =
+ ::tensorflow::CreateOpKernel(::tensorflow::DeviceType("barDev"),
+ nullptr, nullptr, def, 1, &status);
+ TF_EXPECT_OK(status);
+ ASSERT_NE(nullptr, kernel.get());
+ kernel->Compute(nullptr);
+ }
+
+ ASSERT_TRUE(delete_called);
+}
diff --git a/tensorflow/core/kernels/captured_function.h b/tensorflow/c/test_op1.cc
similarity index 68%
rename from tensorflow/core/kernels/captured_function.h
rename to tensorflow/c/test_op1.cc
index 2d2d87134e786139386509c6e5f353bb88882915..b22cc9aef2b344282f45340ff12ee849935a26f9 100644
--- a/tensorflow/core/kernels/captured_function.h
+++ b/tensorflow/c/test_op1.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_
-#define TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_
-#include "tensorflow/core/kernels/data/captured_function.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
-#endif // TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_
+namespace tensorflow {
+
+REGISTER_OP("TestCApi1").Doc(R"doc(Used to test C API)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index c18b07603ae3841d3581741ab5a43f2e8b628356..a09becc49b10d2c58f98fbcc11df5190f794c1d4 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -170,6 +170,7 @@ cc_library_with_android_deps(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -488,6 +489,7 @@ tf_gen_op_wrappers_cc(
"image_ops",
"io_ops",
"linalg_ops",
+ "list_ops",
"logging_ops",
"lookup_ops",
"manip_ops",
@@ -516,6 +518,8 @@ tf_gen_op_wrappers_cc(
":array_ops",
":const_op",
":math_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope",
],
)
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 6abc9e268e3ac97379954a34017ddffa010db67f..81785b2d89b3d36b46992b7ae376b5175a806027 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -95,6 +95,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
+ xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -112,6 +113,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
+ xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -135,6 +137,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
+ xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -167,6 +170,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
+ xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -183,6 +187,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
+ xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -200,6 +205,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
kernel_label_(kernel_label),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
+ xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -217,6 +223,7 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
+ xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(
clear_colocations
? std::unordered_set()
@@ -237,6 +244,25 @@ Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(assigned_device),
+ xla_cluster_(other.impl()->xla_cluster_),
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+
+Scope::Impl::Impl(const Scope& other, Tags::XlaCluster,
+ const string& xla_cluster)
+ : graph_(other.impl()->graph_),
+ status_(other.impl()->status_),
+ name_map_(other.impl()->name_map_),
+ refiner_(other.impl()->refiner_),
+ scope_used_(other.impl()->scope_used_),
+ control_deps_(other.impl()->control_deps_),
+ name_(other.impl()->name_),
+ op_name_(other.impl()->op_name_),
+ exit_on_error_(other.impl()->exit_on_error_),
+ kernel_label_(other.impl()->kernel_label_),
+ device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
+ xla_cluster_(xla_cluster),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -326,6 +352,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
if (!impl()->assigned_device_.empty()) {
builder->AssignedDevice(impl()->assigned_device_);
}
+ if (!impl()->xla_cluster_.empty()) {
+ builder->XlaCluster(impl()->xla_cluster_);
+ }
}
string Scope::Impl::GetUniqueName(const string& prefix,
@@ -388,7 +417,7 @@ Scope Scope::NewSubScope(const string& child_scope_name) const {
false /* copy_names */));
}
-Scope Scope::WithOpName(const string& op_name) const {
+Scope Scope::WithOpNameImpl(const string& op_name) const {
if (impl()->single_use_scope()) {
UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name,
" on this scope"));
@@ -425,6 +454,10 @@ Scope Scope::WithAssignedDevice(const string& assigned_device) const {
return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device));
}
+Scope Scope::WithXlaCluster(const string& xla_cluster) const {
+ return Scope(new Impl(*this, Impl::Tags::XlaCluster(), xla_cluster));
+}
+
Scope Scope::ColocateWith(const Operation& op) const {
return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
/* clear_colocations */ false));
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index e307d8989b6647dfac8d2691ed2171c86b7f3a7c..0a75f23725c143e6b22ee6dffae1428ed8209fe8 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -22,6 +22,7 @@ limitations under the License.
#include
#include
+#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -69,8 +70,9 @@ struct CompositeOpScopes;
/// // W will be named "linear/W"
/// auto W = Variable(linear.WithOpName("W"),
/// {2, 2}, DT_FLOAT);
-/// // b will be named "linear/b"
-/// auto b = Variable(linear.WithOpName("b"),
+/// // b will be named "linear/b_3"
+/// int idx = 3;
+/// auto b = Variable(linear.WithOpName("b_", idx),
/// {2}, DT_FLOAT);
/// auto x = Const(linear, {...}); // name: "linear/Const"
/// auto m = MatMul(linear, x, W); // name: "linear/MatMul"
@@ -113,8 +115,11 @@ class Scope {
Scope NewSubScope(const string& child_scope_name) const;
/// Return a new scope. All ops created within the returned scope will have
- /// names of the form `name/op_name[_suffix]`.
- Scope WithOpName(const string& op_name) const;
+ /// names of the form `name/StrCat(fragments...)[_suffix]`
+ template
+ Scope WithOpName(Ty... fragments) const {
+ return WithOpNameImpl(absl::StrCat(fragments...));
+ }
/// Return a new scope. All ops created within the returned scope will have as
/// control dependencies the union of operations in the control_deps vector
@@ -137,6 +142,10 @@ class Scope {
/// their assigned device set to `assigned_device`.
Scope WithAssignedDevice(const string& assigned_device) const;
+ /// Returns a new scope. All ops created within the returned scope will have
+ /// their _XlaCluster attribute set to `xla_cluster`.
+ Scope WithXlaCluster(const string& xla_cluster) const;
+
/// Return a new scope. All ops created within the returned scope will be
/// co-located on the device where op is placed.
/// NOTE: This function is intended to be use internal libraries only for
@@ -227,6 +236,8 @@ class Scope {
// END_SKIP_DOXYGEN
private:
+ Scope WithOpNameImpl(const string& op_name) const;
+
friend class InternalScope;
std::unique_ptr impl_;
explicit Scope(Impl*);
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 514e02e84146b6d95147d83182e5d9a07509cfa1..5db7eab2b819c2c5d8fc358953d4607848f1cba5 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -61,6 +61,7 @@ class Scope::Impl {
enum class KernelLabel;
enum class Colocate;
enum class AssignedDevice;
+ enum class XlaCluster;
};
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
@@ -78,6 +79,7 @@ class Scope::Impl {
Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
bool clear_colocations);
Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device);
+ Impl(const Scope& other, Tags::XlaCluster, const string& xla_cluster);
std::unordered_set GetColocationConstraints(
const Operation& colocate_with_op) const;
@@ -112,6 +114,7 @@ class Scope::Impl {
const string kernel_label_ = "";
const string device_ = "";
const string assigned_device_ = "";
+ const string xla_cluster_ = "";
const std::unordered_set colocation_constraints_;
// If true, Scope::DoShapeInference() always returns Status:OK().
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index 3d3895c8fa82c3c0e2974228e9cad767d0e00df4..52345a376cc29ee47ccb9888c9bb26292468b5a9 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -133,5 +133,6 @@ filegroup(
"testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**",
+ "testdata/half_plus_two_v2/**",
]),
)
diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h
index 645a3f101d1ae7dda88ec4ca622c694dc5a7a919..6f00dc324bd7054b28de2c35023581e1666bfa01 100644
--- a/tensorflow/cc/saved_model/constants.h
+++ b/tensorflow/cc/saved_model/constants.h
@@ -33,10 +33,10 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
/// SavedModel text format proto filename.
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
-/// SavedModel legacy init op key.
+/// SavedModel legacy init op collection key. Used in v1 SavedModels.
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
-/// SavedModel main op key.
+/// SavedModel main op collection key. Used in v1 SavedModels.
constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
/// Directory in which to save the SavedModel variables.
@@ -45,6 +45,11 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
/// SavedModel variables filename.
constexpr char kSavedModelVariablesFilename[] = "variables";
+/// SavedModel SignatureDef keys for the initialization and train ops. Used in
+/// V2 SavedModels.
+constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
+constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op";
+
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index c6abe2f41b9b5ec2faee6f65b429ff606f8ac08e..85d3dd01fa51b3c3ba6fcbf5faac03f1ff5630e2 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -122,34 +122,54 @@ Status RunOnce(const RunOptions& run_options,
return run_status;
}
-bool HasMainOp(const MetaGraphDef& meta_graph_def) {
+// RunInitOp will return OK if the initialization op was run successfully.
+// An empty init_op_name indicates that there are no init ops to run.
+Status RunInitOp(const RunOptions& run_options, const string& export_dir,
+ const MetaGraphDef& meta_graph_def,
+ const std::vector& asset_file_defs,
+ Session* session, const string& init_op_name) {
+ if (!init_op_name.empty()) {
+ LOG(INFO) << "Running initialization op on SavedModel bundle.";
+ std::vector> inputs;
+ AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
+ RunMetadata run_metadata;
+ return RunOnce(run_options, inputs, {}, {init_op_name},
+ nullptr /* outputs */, &run_metadata, session);
+ }
+ return Status::OK();
+}
+
+// A SavedModel may store the name of the initialization op to run in the
+// in the SignatureDef (v2) or a collection (v1). If an init_op collection
+// exists, then the collection must contain exactly one op.
+Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
+ string* init_op_name) {
+ const auto& sig_def_map = meta_graph_def.signature_def();
+ const auto& init_op_sig_it =
+ meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
+ if (init_op_sig_it != sig_def_map.end()) {
+ *init_op_name = init_op_sig_it->second.outputs()
+ .find(kSavedModelInitOpSignatureKey)
+ ->second.name();
+ return Status::OK();
+ }
+
const auto& collection_def_map = meta_graph_def.collection_def();
+ string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
- return true;
+ init_op_collection_key = kSavedModelMainOpKey;
+ } else {
+ init_op_collection_key = kSavedModelLegacyInitOpKey;
}
- return false;
-}
-Status RunMainOp(const RunOptions& run_options, const string& export_dir,
- const MetaGraphDef& meta_graph_def,
- const std::vector& asset_file_defs,
- Session* session, const string& main_op_key) {
- LOG(INFO) << "Running MainOp with key " << main_op_key
- << " on SavedModel bundle.";
- const auto& collection_def_map = meta_graph_def.collection_def();
- const auto main_op_it = collection_def_map.find(main_op_key);
- if (main_op_it != collection_def_map.end()) {
- if (main_op_it->second.node_list().value_size() != 1) {
+ const auto init_op_it = collection_def_map.find(init_op_collection_key);
+ if (init_op_it != collection_def_map.end()) {
+ if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
- std::vector> inputs;
- AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
- RunMetadata run_metadata;
- const StringPiece main_op_name = main_op_it->second.node_list().value(0);
- return RunOnce(run_options, inputs, {}, {string(main_op_name)},
- nullptr /* outputs */, &run_metadata, session);
+ *init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
@@ -193,6 +213,15 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector* asset_file_defs) {
+ // With SavedModel v2, we write asset file def into metagraph instead of
+ // collection, so read from metagraph first.
+ if (meta_graph_def.asset_file_def_size() > 0) {
+ for (const auto& asset : meta_graph_def.asset_file_def()) {
+ asset_file_defs->push_back(asset);
+ }
+ return Status::OK();
+ }
+ // Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
@@ -227,15 +256,12 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(),
asset_file_defs, bundle->session.get()));
- if (HasMainOp(bundle->meta_graph_def)) {
- TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir,
- bundle->meta_graph_def, asset_file_defs,
- bundle->session.get(), kSavedModelMainOpKey));
- } else {
- TF_RETURN_IF_ERROR(RunMainOp(
- run_options, export_dir, bundle->meta_graph_def, asset_file_defs,
- bundle->session.get(), kSavedModelLegacyInitOpKey));
- }
+ string init_op_name;
+ TF_RETURN_IF_ERROR(
+ GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
+ TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
+ asset_file_defs, bundle->session.get(),
+ init_op_name));
return Status::OK();
}
diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc
index 72b8bc18710b0ee77cb01ed3ad0c2abb5183efb2..597e42bb65ab5536664089f7e65ec52d77fc8f23 100644
--- a/tensorflow/cc/saved_model/loader_test.cc
+++ b/tensorflow/cc/saved_model/loader_test.cc
@@ -36,6 +36,8 @@ constexpr char kTestDataMainOp[] =
"cc/saved_model/testdata/half_plus_two_main_op/00000123";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123";
+constexpr char kTestDataInitOpV2[] =
+ "cc/saved_model/testdata/half_plus_two_v2/00000123";
class LoaderTest : public ::testing::Test {
protected:
@@ -227,5 +229,17 @@ TEST_F(LoaderTest, MaybeSavedModelDirectory) {
EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
}
+TEST_F(LoaderTest, SavedModelInitOpV2Format) {
+ SavedModelBundle bundle;
+ SessionOptions session_options;
+ RunOptions run_options;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2);
+ TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
+ {kSavedModelTagServe}, &bundle));
+ CheckSavedModelBundle(export_dir, bundle);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f9ff036688007836524129e23f5cf82edd1e8910
--- /dev/null
+++ b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt
@@ -0,0 +1 @@
+asset-file-contents
\ No newline at end of file
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb
new file mode 100644
index 0000000000000000000000000000000000000000..a10bbf8fb6bca0fcee6414b2927d2f706de85ebc
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb differ
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001
new file mode 100644
index 0000000000000000000000000000000000000000..15b75d6ef6bffc336d138d923badb3928b8c4c13
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 differ
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index
new file mode 100644
index 0000000000000000000000000000000000000000..7ec9fb4fe2dd21d0a6c324aecd7658fc37cf2326
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index differ
diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df80ec01245a7fe820c79d5879458c4cd0a93cb
--- /dev/null
+++ b/tensorflow/compat_template_v1.__init__.py
@@ -0,0 +1,34 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Bring in all of the public TensorFlow interface into this module."""
+
+from __future__ import absolute_import as _absolute_import
+from __future__ import division as _division
+from __future__ import print_function as _print_function
+
+import os as _os
+
+# pylint: disable=g-bad-import-order
+from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
+
+from tensorflow.python.tools import component_api_helper as _component_api_helper
+_component_api_helper.package_hook(
+ parent_package_str=__name__,
+ child_package_str=('tensorflow_estimator.python.estimator.api.estimator'))
+
+# API IMPORTS PLACEHOLDER
+
+from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
+app.flags = flags # pylint: disable=undefined-variable
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 6c29f09cde7ee17c11cb44ce48d8e9128daae4d0..16151e77737429f4fbf690fc34b12a70bacebdc4 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -93,7 +93,7 @@ cc_library(
":tfcompile_lib",
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
- "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
+ "//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index b17bc658fa06b9feb7edb292bd89ef31e6309169..e0ac7130a64d3928c39440c0e10a2d2e1990b9cd 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -164,7 +164,8 @@ string RewriteWithName(const string& name, string code,
}
// Generate methods for args (inputs).
-Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
+Status GenArgMethods(const tf2xla::Config& config,
+ const xla::ProgramShapeProto& ps,
const CompileResult& compile_result, string* methods) {
size_t num_args = ps.parameters_size();
if (config.feed_size() != num_args) {
@@ -174,7 +175,8 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
}
for (int i = 0; i < num_args; ++i) {
std::vector> rewrites;
- TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites));
+ TF_RETURN_IF_ERROR(
+ AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
const string code = R"(
void set_arg{{NAME}}_data(void* data) {
set_arg_data({{I}}, data);
@@ -204,7 +206,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
// Generate methods for results (outputs).
Status GenResultMethods(const tf2xla::Config& config,
- const xla::ProgramShape& ps, string* methods) {
+ const xla::ProgramShapeProto& ps, string* methods) {
if (ps.result().element_type() != xla::TUPLE) {
// The XlaCompiler we use to build the xla computation always generates a
// tuple result, and we rely on this to simplify code generation.
@@ -217,8 +219,8 @@ Status GenResultMethods(const tf2xla::Config& config,
}
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
std::vector> rewrites;
- TF_RETURN_IF_ERROR(
- AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites));
+ TF_RETURN_IF_ERROR(AddRewritesForShape(
+ i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
string code = R"(
{{TYPE}}* result{{NAME}}_data() {
return static_cast<{{TYPE}}*>(result_data({{I}}));
@@ -336,7 +338,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
ExtractEntryParamBufferInfos(buffer_infos);
std::vector buffer_infos_for_temps =
ExtractTempBufferInfos(buffer_infos);
- const xla::ProgramShape& ps = compile_result.program_shape;
+ const xla::ProgramShapeProto& ps = compile_result.program_shape;
string methods_arg, methods_result;
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
@@ -548,8 +550,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
// Shape of the args and results.
- static const xla::ProgramShape* StaticProgramShape() {
- static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
+ static const xla::ProgramShapeProto* StaticProgramShape() {
+ static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
return kShape;
}
@@ -587,7 +589,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{METHODS_RESULT}}\n", methods_result},
{"{{NS_END}}\n", ns_end},
{"{{NS_START}}\n", ns_start},
- {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
+ {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
metadata_result.program_shape_access_shim},
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
@@ -615,11 +617,11 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts,
Status GenerateMetadata(const CodegenOpts& opts,
const CompileResult& compile_result,
MetadataResult* metadata_result) {
- std::unique_ptr program_shape;
+ std::unique_ptr program_shape;
if (opts.gen_program_shape) {
program_shape =
- absl::make_unique(compile_result.program_shape);
+ absl::make_unique(compile_result.program_shape);
// The parameter names are currently meaningless, and redundant with the
// rest of our metadata, so clear them out to avoid confusion and save
@@ -631,8 +633,8 @@ Status GenerateMetadata(const CodegenOpts& opts,
// a shim that evaluates to nullptr, which is what we want.
ProtobufToEmbed program_shape_protobuf{
- CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape",
- program_shape.get()};
+ CreateUniqueIdentifier(opts, "ProgramShapeProto"),
+ "xla::ProgramShapeProto", program_shape.get()};
ProtobufToEmbed hlo_profile_printer_data_protobuf{
CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h
index 90410c46a8e36e44454f1219ad76d0fb0937070d..9485e86b10e225a3c9c12eafd9905bdf7c15c9fa 100644
--- a/tensorflow/compiler/aot/codegen.h
+++ b/tensorflow/compiler/aot/codegen.h
@@ -57,7 +57,7 @@ struct MetadataResult {
std::vector header_variable_decls;
// program_shape_access_shim is a C++ expression that constructs the
- // xla::ProgramShape instance for the CompileResult passed to
+ // xla::ProgramShapeProto instance for the CompileResult passed to
// GenerateMetadata.
string program_shape_access_shim;
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index bb288d23000527be74f01630d20bbf82e50007ce..c1788ca32a1d099284eeb870f9513891051fd29e 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -181,13 +181,15 @@ TEST(CodegenTest, Golden) {
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
5, {}));
- compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
- {
- xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
- xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
- },
- xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}));
+ compile_result.program_shape =
+ xla::ShapeUtil::MakeProgramShape(
+ {
+ xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
+ xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
+ },
+ xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}))
+ .ToProto();
compile_result.entry_point = "entry_point";
compile_result.pointer_size = 8;
diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden
index e4d8a02877c75fa72c5747650ab9c7ac229955b3..a2cdab5d1a8e72504ca11b789287d4efd07a59e9 100644
--- a/tensorflow/compiler/aot/codegen_test_h.golden
+++ b/tensorflow/compiler/aot/codegen_test_h.golden
@@ -22,7 +22,7 @@ extern "C" void entry_point(
void* result, const xla::ExecutableRunOptions* run_options,
const void** args, void** temps, tensorflow::int64* profile_counters);
-extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[];
+extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[];
namespace foo {
@@ -253,10 +253,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
}
// Shape of the args and results.
- static const xla::ProgramShape* StaticProgramShape() {
- static const xla::ProgramShape* kShape = []() {
- xla::ProgramShape* proto = new xla::ProgramShape;
- proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52);
+ static const xla::ProgramShapeProto* StaticProgramShape() {
+ static const xla::ProgramShapeProto* kShape = []() {
+ xla::ProgramShapeProto* proto = new xla::ProgramShapeProto;
+ proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52);
return proto;
}();
return kShape;
diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden
index eb001c5d45bdfefc76629d7303d89f5480432235..ce8e5ec8c96a2c3696f14b8eea206d648182ecb5 100644
Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 2b5f97b34cd928d32eb220536342c715d91d45bb..9fc223bdc7c0e207ce2005cb86250aa77e709df8 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -56,17 +56,23 @@ Status CompileXla(xla::CompileOnlyClient* client,
return errors::Unknown("Couldn't get XLA program shape: ",
pshape_or.status().error_message());
}
- compile_result->program_shape = *pshape_or.ValueOrDie();
- xla::ProgramShape* pshape = &compile_result->program_shape;
- std::vector arg_layouts;
- arg_layouts.reserve(pshape->parameters_size());
+ compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
+ xla::ProgramShapeProto* pshape = &compile_result->program_shape;
+
+ // AotXlaComputationInstance::argument_layouts is a vector of Shape
+ // pointers. Accumulate the Shape objects themselves in a separate vector
+ // while building the vector of pointers.
+ std::vector arg_layout_ptrs(pshape->parameters_size());
+ std::vector arg_layouts(pshape->parameters_size());
for (int i = 0; i < pshape->parameters_size(); ++i) {
- arg_layouts.push_back(pshape->mutable_parameters(i));
+ arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i));
+ arg_layout_ptrs[i] = &arg_layouts[i];
}
xla::CompileOnlyClient::AotXlaComputationInstance instance;
instance.computation = &computation;
- instance.argument_layouts = std::move(arg_layouts);
- instance.result_layout = &pshape->result();
+ instance.argument_layouts = std::move(arg_layout_ptrs);
+ xla::Shape result_shape(pshape->result());
+ instance.result_layout = &result_shape;
xla::StatusOr>>
aot_or = client->CompileAheadOfTime({instance}, aot_opts);
if (!aot_or.ok()) {
diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h
index e03c5b1aa77c1262ed903aae3072ef65f34d80a2..ee7bb26fabd2d897b85b62f38778ecbfe2238eb6 100644
--- a/tensorflow/compiler/aot/compile.h
+++ b/tensorflow/compiler/aot/compile.h
@@ -33,9 +33,9 @@ namespace tfcompile {
struct CompileResult {
// Contains object file and meta-info.
std::unique_ptr aot;
- xla::ProgramShape program_shape; // Static shape of args and results.
- string entry_point; // Name of generated function.
- int pointer_size = 0; // Size of a pointer in bytes.
+ xla::ProgramShapeProto program_shape; // Static shape of args and results.
+ string entry_point; // Name of generated function.
+ int pointer_size = 0; // Size of a pointer in bytes.
};
// CompileGraph compiles the graph_def into an object file containing a function
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index f10852c7850f61bfd8b99fa9f1648202d182085e..4dd79e5882d7da61be029735ef2b165908c599f9 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -526,13 +526,15 @@ TEST(TFCompileTest, ProgramShape) {
// muladd has the program shape defined.
MatMulAndAddComp muladd;
- const xla::ProgramShape* muladd_shape = muladd.ProgramShape();
+ const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
ASSERT_TRUE(muladd_shape != nullptr);
ASSERT_EQ(muladd_shape->parameters_size(), 2);
- EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));
- EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2));
+ EXPECT_TRUE(
+ ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(0)), f32_2x2));
+ EXPECT_TRUE(
+ ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(1)), f32_2x2));
- const xla::Shape& muladd_result = muladd_shape->result();
+ const xla::Shape muladd_result(muladd_shape->result());
ASSERT_EQ(muladd_result.element_type(), xla::TUPLE);
ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2);
const xla::Shape& muladd_result0 =
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 859c84bb91657422b830255b0217f8946d351458..2dc3e8c9113b37bf9d575ad66783f4ab49478af4 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -390,6 +390,7 @@ def target_llvm_triple():
"//tensorflow:android_arm": "armv7-none-android",
"//tensorflow:android_arm64": "aarch64-none-android",
"//tensorflow:android_x86": "i686-none-android",
+ "//tensorflow:ios": "arm64-none-ios",
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//tensorflow:darwin": "x86_64-none-darwin",
"//conditions:default": "x86_64-pc-linux",
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index b95b063348c5cdfdcaed635ba527e9f0bfd6092d..d548de8c44285f6d21dd778db464a31e1b19645b 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
-#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
+#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -103,7 +103,7 @@ Status Main(const MainFlags& flags) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
- xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile();
+ xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
@@ -132,7 +132,7 @@ int main(int argc, char** argv) {
std::vector flag_list;
AppendMainFlags(&flag_list, &flags);
- xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
+ xla::AppendDebugOptionsFlags(&flag_list);
tensorflow::string usage = tensorflow::tfcompile::kUsageHeader;
usage += tensorflow::Flags::Usage(argv[0], flag_list);
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index ced0cd03f74d147451ca2bf54108dc7517b50acd..be91ed4f432b1890c22900f293fd4196e5c9d970 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -21,10 +21,8 @@ package(
)
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
-load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -39,7 +37,7 @@ cc_library(
":xla_cpu_device",
":xla_cpu_jit",
"//tensorflow/compiler/plugin",
- ] + if_cuda_is_configured([
+ ] + if_cuda([
":xla_gpu_device",
":xla_gpu_jit",
]),
@@ -52,6 +50,8 @@ cc_library(
deps = [
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",
],
@@ -65,6 +65,7 @@ cc_library(
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/xla/service:gpu_plugin",
]),
alwayslink = 1,
@@ -75,15 +76,16 @@ cc_library(
srcs = ["xla_cpu_device.cc"],
visibility = [":friends"],
deps = [
+ ":flags",
":jit_compilation_passes",
":xla_device",
"//tensorflow/compiler/jit/kernels:xla_ops",
- "//tensorflow/compiler/jit/legacy_flags:xla_device_flags",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
@@ -101,6 +103,7 @@ cc_library(
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
@@ -116,7 +119,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
@@ -188,11 +191,13 @@ cc_library(
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:shape_ops",
+ "//tensorflow/core/kernels:stack",
"//tensorflow/core/kernels:variable_ops",
"//tensorflow/core/kernels/data:generator_dataset_op",
"//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/synchronization",
],
)
@@ -205,6 +210,18 @@ cc_library(
# Internal targets below this point.
+cc_library(
+ name = "flags",
+ srcs = ["flags.cc"],
+ hdrs = ["flags.h"],
+ visibility = [":friends"],
+ deps = [
+ "//tensorflow/compiler/xla:parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
cc_library(
name = "common",
srcs = [
@@ -237,6 +254,8 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
],
)
@@ -249,6 +268,8 @@ cc_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:debug_options_flags",
+ "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu",
@@ -259,6 +280,22 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "xla_compilation_cache_test",
+ srcs = [
+ "xla_compilation_cache_test.cc",
+ ],
+ deps = [
+ ":xla_compilation_cache",
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
],
)
@@ -359,6 +396,83 @@ tf_cc_test(
],
)
+cc_library(
+ name = "shape_inference",
+ srcs = ["shape_inference.cc"],
+ hdrs = ["shape_inference.h"],
+ deps = [
+ ":shape_inference_helpers",
+ "//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "test_util",
+ testonly = 1,
+ srcs = ["test_util.cc"],
+ hdrs = ["test_util.h"],
+ deps = [
+ ":shape_inference",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "shape_inference_test",
+ srcs = ["shape_inference_test.cc"],
+ deps = [
+ ":shape_inference",
+ ":test_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:ops",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/kernels:constant_op",
+ ],
+)
+
+cc_library(
+ name = "encapsulate_util",
+ srcs = ["encapsulate_util.cc"],
+ hdrs = ["encapsulate_util.h"],
+ deps = [
+ ":shape_inference",
+ "//tensorflow/compiler/tf2xla:tf2xla_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+tf_cc_test(
+ name = "encapsulate_util_test",
+ srcs = ["encapsulate_util_test.cc"],
+ deps = [
+ ":encapsulate_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
cc_library(
name = "compilation_passes",
srcs = [
@@ -367,6 +481,8 @@ cc_library(
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
"encapsulate_xla_computations_pass.cc",
+ "extract_outside_compilation_pass.cc",
+ "increase_dynamism_for_auto_jit_pass.cc",
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
@@ -376,12 +492,16 @@ cc_library(
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"encapsulate_xla_computations_pass.h",
+ "extract_outside_compilation_pass.h",
+ "increase_dynamism_for_auto_jit_pass.h",
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
],
deps = [
":common",
+ ":encapsulate_util",
+ ":flags",
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
@@ -389,13 +509,13 @@ cc_library(
"//tensorflow/cc:ops",
"//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles",
- "//tensorflow/compiler/jit/legacy_flags:build_xla_ops_pass_flags",
- "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table",
+ "//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
+ "//tensorflow/compiler/tf2xla/cc:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
@@ -409,8 +529,10 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -435,25 +557,6 @@ cc_library(
hdrs = ["union_find.h"],
)
-cc_library(
- name = "producer_consumer_queue",
- hdrs = ["producer_consumer_queue.h"],
- deps = ["//tensorflow/core:lib"],
-)
-
-tf_cc_test(
- name = "producer_consumer_queue_test",
- size = "small",
- srcs = ["producer_consumer_queue_test.cc"],
- deps = [
- ":producer_consumer_queue",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
- ],
-)
-
tf_cc_test(
name = "deadness_analysis_test",
size = "small",
@@ -491,12 +594,15 @@ tf_cc_test(
"build_xla_ops_pass_test.cc",
"encapsulate_subgraphs_pass_test.cc",
"encapsulate_xla_computations_pass_test.cc",
+ "extract_outside_compilation_pass_test.cc",
+ "increase_dynamism_for_auto_jit_pass_test.cc",
"mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc",
],
deps = [
":common",
":compilation_passes",
+ ":encapsulate_util",
":node_matchers",
":xla_cluster_util",
":xla_cpu_device",
@@ -506,17 +612,21 @@ tf_cc_test(
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
+ "//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_ops",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/tf2xla/cc:xla_ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
@@ -553,31 +663,6 @@ tf_cc_test(
],
)
-tf_cc_test(
- name = "xla_launch_util_test",
- size = "small",
- srcs = ["xla_launch_util_test.cc"],
- deps = [
- ":common",
- ":xla_compilation_cache",
- ":xla_launch_util",
- ":xla_tensor",
- "//tensorflow/compiler/tf2xla:common",
- "//tensorflow/compiler/tf2xla:xla_compiler",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:gpu_runtime",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core/kernels:variable_ops",
- ],
-)
-
cc_library(
name = "xla_fusion_optimizer",
srcs = ["xla_fusion_optimizer.cc"],
@@ -652,7 +737,10 @@ tf_custom_op_py_library(
visibility = [
":friends",
],
- deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"],
+ deps = [
+ "//tensorflow/compiler/jit/ops:xla_ops_grad",
+ "//tensorflow/compiler/jit/ops:xla_ops_wrapper_py",
+ ],
)
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index 054f31ba3352b2215e6b0448c8ec8a70cb98b8e5..9f4042630edaec1b9519b6434d859a48372e8b15 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/cc/ops/control_flow_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h"
+#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
@@ -214,7 +214,8 @@ Status NodeRequiresCompilation(Node* n, bool* result) {
return errors::Internal("Could not find compilation device ",
device_type.type());
}
- *result = registration->requires_compilation;
+ *result = registration->autoclustering_policy ==
+ XlaOpRegistry::AutoclusteringPolicy::kAlways;
return Status::OK();
}
@@ -319,10 +320,10 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
return IsXlaCompiledKernel(*n);
});
- bool lazy_compilation_enabled = enable_lazy_compilation_
- ? *enable_lazy_compilation_
- : legacy_flags::GetBuildXlaOpsPassFlags()
- .tf_xla_enable_lazy_compilation;
+ bool lazy_compilation_enabled =
+ enable_lazy_compilation_
+ ? *enable_lazy_compilation_
+ : GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation;
for (Node* n : xla_compiled_kernels) {
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
index 11df946cc186660242574c2644463a26ead44f1f..48a23a4c1711ac88a329723c46559112d5a39dbd 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -42,14 +42,8 @@ class BuildXlaOpsTest : public ::testing::Test {
.ok());
}
- void TearDown() override {
- for (Device* device : devices_) {
- delete device;
- }
- }
-
private:
- std::vector devices_;
+ std::vector> devices_;
};
using ::tensorflow::testing::FindNodeByName;
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
index 73866607621cd745f6e640a14405daebf0dd9985..0f872a480f4d4843217f1df3452c4dc62531264e 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
@@ -59,8 +59,9 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 1});
+ std::vector> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
- options, "/job:localhost/replica:0/task:0", &devices_));
+ options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto;
for (const auto& fdef : flib) {
@@ -69,7 +70,7 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
lib_def_ = absl::make_unique(
OpRegistry::Global(), proto);
OptimizerOptions opts;
- device_mgr_ = absl::make_unique(devices_);
+ device_mgr_ = absl::make_unique(std::move(devices));
pflr_ = absl::make_unique(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
@@ -77,7 +78,6 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
}
FunctionLibraryRuntime* flr_;
- std::vector devices_;
std::unique_ptr device_mgr_;
std::unique_ptr lib_def_;
std::unique_ptr pflr_;
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index b7ae7fbeb3912882368dc828e8d6fcd50735b04e..0562838f628c66b1eb03af9d2a5139c01dca31c5 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -525,7 +525,6 @@ Predicate* PredicateFactory::MakeAndOrImpl(
op->GetOperands().begin(),
op->GetOperands().end());
} else {
- std::vector sub_ops_intersection;
common_inner_operands.clear();
absl::c_copy_if(op->GetOperands(),
std::back_inserter(common_inner_operands),
@@ -696,8 +695,8 @@ Status CreateMultipleNextIterationInputsError(Node* merge) {
}
}
return errors::InvalidArgument(
- "Multiple NextIteration inputs to merge node ", SummarizeNode(*merge),
- ": \n", absl::StrJoin(backedges, "\n"),
+ "Multiple NextIteration inputs to merge node ",
+ FormatNodeForError(*merge), ": \n", absl::StrJoin(backedges, "\n"),
"\nMerge nodes can have at most one incoming NextIteration edge.");
}
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index 617e31488c7daeb714c0ff7056b786e4eaf7873f..8a73101c184e6190921fd7729742922bd96f4bcf 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -127,7 +127,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root,
Output loop_cond =
ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr);
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
- ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
+ ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
+ latch.output_false);
Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
latch.output_true, increment_by);
Output next_iteration =
@@ -191,7 +192,8 @@ DependentInductionVar CreateDependentLoopInvariantValue(
value, frame_name);
ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
- ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
+ ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
+ latch.output_false);
Output next_iteration = ops::NextIteration(
root.WithOpName(prefix + "/next_iteration"), latch.output_true);
CHECK(root.graph()
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index da030b3bcc7aacae2306bec30f4b8927aa042d7c..f478832781cb1dc045d9163d4a6f5e5f64a8a705 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -1122,8 +1122,11 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
fdef);
}
- if (!reuse_existing_functions || library->Find(name) == nullptr) {
+ const FunctionDef* original_fdef = library->Find(name);
+ if (!reuse_existing_functions || original_fdef == nullptr) {
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
+ } else if (!FunctionDefsEqual(*original_fdef, fdef)) {
+ TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index 49958093b8dcf35e8adcdfd2f7dfce8558d5db6f..de89be9a3555960dabe7bacd17226c15ae888ae6 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -16,16 +16,20 @@ limitations under the License.
#include
#include
-#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/encapsulate_util.h"
+#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
@@ -406,8 +410,8 @@ Node* KeyPlaceholderShape(const GraphDefBuilder::Options& opts) {
Node* KeyPlaceholder(const string& call_node,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
- NodeBuilder node_builder(opts.GetNameForOp("Placeholder"), "Placeholder",
- opts.op_registry());
+ NodeBuilder node_builder(absl::StrCat(call_node, "_key_placeholder"),
+ "Placeholder", opts.op_registry());
TensorShapeProto shape;
shape.add_dim()->set_size(2);
return opts.WithAttr("shape", shape)
@@ -494,7 +498,8 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
return opts.FinalizeBuilder(&node_builder);
}
-Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
+Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
+ const std::vector& encapsulated_functions) {
Status s;
// Convert the GraphDef to a Graph
std::unique_ptr lib_def(
@@ -505,11 +510,39 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
s = ConvertGraphDefToGraph(options, *graphdef, graph.get());
if (!s.ok()) return s;
+ s = PerformStaticShapeInferenceBeforeEncapsulation(
+ graph.get(), "_encapsulate", "_outside");
+ if (!s.ok()) return s;
+
+ s = PreprocessForEncapsulation(graph.get(), "_encapsulate", "_outside");
+ if (!s.ok()) return s;
+
std::unique_ptr graph_out;
- s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph,
- /*rewrite_subgraph_fn=*/{},
- /*reuse_existing_functions=*/false,
- &graph_out, lib_def.get());
+ s = EncapsulateSubgraphsInFunctions(
+ "_encapsulate", /*outside_compilation_attribute=*/"", *graph,
+ /*rewrite_subgraph_fn=*/{},
+ /*reuse_existing_functions=*/false, &graph_out, lib_def.get());
+ if (!s.ok()) return s;
+
+ std::unordered_map clusters;
+ for (const auto& func : encapsulated_functions) {
+ Node* xla_computation_node;
+ for (Node* n : graph_out->nodes()) {
+ if (n->name() == func) {
+ xla_computation_node = n;
+ }
+ }
+ if (!xla_computation_node) {
+ return errors::Internal("Cannot find node ", func);
+ }
+ NameAttrList func_name_attrs;
+ func_name_attrs.set_name(func);
+ clusters.emplace(func,
+ XlaClusterInfo{func, func_name_attrs, xla_computation_node,
+ std::map{}});
+ }
+ s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters,
+ graph_out.get(), lib_def.get());
if (!s.ok()) return s;
GraphDef graphdef_out;
@@ -520,6 +553,11 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
return s;
}
+Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
+ std::vector encapsulated_functions;
+ return Encapsulate(graphdef, library, encapsulated_functions);
+}
+
// If there are no marked nodes, funcification should be a no-op.
TEST(EncapsulateSubgraphsTest, NoFunctions) {
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
@@ -703,7 +741,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
std::unique_ptr graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
- "_cluster", "_outside", graph_before_encapsulation,
+ "_cluster", "", graph_before_encapsulation,
/*rewrite_subgraph_fn=*/{},
/*reuse_existing_functions=*/false, &graph, &library));
@@ -755,7 +793,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
- "_encapsulate", "_outside", graph_before,
+ "_encapsulate", "", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](const std::vector& arg_source_tensors,
std::unique_ptr* graph_ptr,
@@ -800,7 +838,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
- "_encapsulate", "_outside", graph_before,
+ "_encapsulate", "", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](const std::vector& arg_source_tensors,
std::unique_ptr* graph_ptr,
@@ -854,15 +892,15 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
- Node* key_constant =
- KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
+ Node* key_constant = KeyPlaceholder("F1", shape.opts());
Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
{DT_FLOAT, DT_FLOAT}, shape.opts());
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
@@ -877,7 +915,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
@@ -899,7 +937,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
{"_outside_compilation_subgraph", "O1"}},
{"c"}},
},
- {{"f_0_retval", "F:o:0"}});
+ {{"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -975,15 +1013,15 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
- Node* key_constant =
- KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0"));
+ Node* key_constant = KeyPlaceholder("F1", shape1.opts());
Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
{DT_FLOAT, DT_FLOAT}, shape1.opts());
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
@@ -998,8 +1036,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
- Node* key_constant =
- KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0"));
+ Node* key_constant = KeyPlaceholder("F1", shape2.opts());
Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
{DT_FLOAT, DT_FLOAT}, shape2.opts());
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
@@ -1020,7 +1057,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
}
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
@@ -1037,14 +1074,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"F:o:0", "D:o:0"},
{{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})},
{"Toutputs", absl::Span({DT_FLOAT})},
- {"ancestors",
- absl::Span({"outside_compilation_O1_host_compute"})},
+ {"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O2"},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O2"}},
- {"F", "outside_compilation_O1_host_compute"}},
+ {"F"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
@@ -1058,7 +1094,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
- {{"i_0_retval", "I:o:0"}});
+ {{"i_0_retval_retval", "I:o:0"}});
{
std::unique_ptr lib_def(
@@ -1149,33 +1185,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1", "F2"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
- {
- GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
- Node* key_constant =
- KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, shape.opts());
- Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
- shape.opts()
- .WithName("E")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
- TF_EXPECT_OK(
- AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
- }
-
TensorShapeProto shape_proto_expected;
shape_proto_expected.add_dim()->set_size(2);
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
- {"f_0_retval:float", "d_0_retval:float"}, {},
+ {"f_0_retval_retval:float", "d_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1191,19 +1212,19 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
- {"shapes", absl::Span({})},
+ {"shape_inference_graph", ""},
+ {"shapes",
+ absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
- {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});
+ {{"d_0_retval_retval", "D:o:0"}, {"f_0_retval_retval", "F:o:0"}});
*library_expected.add_function() = FunctionDefHelper::Create(
- "F2", {"e_0_arg:float", "f_0_arg:float"},
- {"g_0_retval:float", "i_0_retval:float"}, {},
+ "F2", {"f_0_arg:float", "bridge_e_g_0_arg:float"},
+ {"i_0_retval_retval:float", "g_0_retval_retval:float"}, {},
{
- {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
+ {{"G"}, "BinaryTest", {"bridge_e_g_0_arg", "f_0_arg"}},
{{"I"},
"BinaryTest",
{"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
@@ -1219,7 +1240,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
+ {{"i_0_retval_retval", "I:o:0"}, {"g_0_retval_retval", "G:o:0"}});
{
std::unique_ptr lib_def(
@@ -1265,11 +1286,11 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
"F2");
NodeBuilder node_builder2("F2", "F2", lib_def.get());
- node_builder2.Input(e).Input(call1);
+ node_builder2.Input(call1).Input(e);
Node* call2 = b2.opts()
.WithControlInputs({s2, e, call1})
.FinalizeBuilder(&node_builder2);
- Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J"));
+ Binary(ops::NodeOut(call2, 1), call2, b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1312,44 +1333,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1", "F2"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
-
- {
- GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
- Node* key_constant =
- KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, shape.opts());
- Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
- shape.opts()
- .WithName("E")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
- TF_EXPECT_OK(
- AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
- }
-
- {
- GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
- Node* key_constant =
- KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F2", "O1",
- {DT_FLOAT}, shape.opts());
- Node* h = Unary(recv, shape.opts()
- .WithName("H")
- .WithAttr("_encapsulate", "F2")
- .WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F2", "O1", {h}, shape.opts());
- TF_EXPECT_OK(
- AddGraphDefToFunctionLibrary(shape, "F2_O1", &library_expected));
- }
+ TensorShapeProto shape_proto_expected;
+ shape_proto_expected.add_dim()->set_size(2);
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1365,16 +1358,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
- {"shapes", absl::Span({})},
+ {"shape_inference_graph", ""},
+ {"shapes",
+ absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
- {{"f_0_retval", "F:o:0"}});
+ {{"f_0_retval_retval", "F:o:0"}});
*library_expected.add_function() = FunctionDefHelper::Create(
- "F2", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
+ "F2", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {},
{
{{"G"}, "BinaryTest", {"a_0_arg", "b_0_arg"}},
{{"I"},
@@ -1387,12 +1380,12 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F2_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F2_O1"},
- {"shapes", absl::Span({})},
+ {"shape_inference_graph", ""},
+ {"shapes",
+ absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"i_0_retval", "I:o:0"}});
+ {{"i_0_retval_retval", "I:o:0"}});
{
std::unique_ptr lib_def(
@@ -1439,9 +1432,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
"F2");
NodeBuilder node_builder2("F2", "F2", lib_def.get());
node_builder2.Input(a).Input(b);
- Node* call2 = b2.opts()
- .WithControlInputs({s2, call1})
- .FinalizeBuilder(&node_builder2);
+ Node* call2 =
+ b2.opts().WithControlInputs({s2}).FinalizeBuilder(&node_builder2);
Binary(call1, call2, b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1473,7 +1465,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
@@ -1482,7 +1475,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
shape_proto_expected.add_dim()->set_size(2);
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1501,7 +1494,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"f_0_retval", "F:o:0"}});
+ {{"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -1557,7 +1550,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
@@ -1566,7 +1560,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
shape_proto_expected.add_dim()->set_size(2);
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1586,7 +1580,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
- {{"f_0_retval", "F:o:0"}});
+ {{"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -1644,13 +1638,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1666,7 +1661,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"f_0_retval", "F:o:0"}});
+ {{"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -1721,13 +1716,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1747,7 +1743,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"f_0_retval", "F:o:0"}});
+ {{"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -1811,15 +1807,15 @@ TEST(EncapsulateSubgraphsTest,
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
- Node* key_constant =
- KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0"));
+ Node* key_constant = KeyPlaceholder("F1", shape2.opts());
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
{DT_FLOAT}, shape2.opts());
Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts()
@@ -1832,7 +1828,7 @@ TEST(EncapsulateSubgraphsTest,
}
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1852,7 +1848,7 @@ TEST(EncapsulateSubgraphsTest,
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O2"}}},
},
- {{"h_0_retval", "H:o:0"}});
+ {{"h_0_retval_retval", "H:o:0"}});
{
std::unique_ptr lib_def(
@@ -1920,15 +1916,15 @@ TEST(EncapsulateSubgraphsTest,
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
- Node* key_constant =
- KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0"));
+ Node* key_constant = KeyPlaceholder("F1", shape1.opts());
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
{DT_FLOAT}, shape1.opts());
Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
@@ -1941,7 +1937,7 @@ TEST(EncapsulateSubgraphsTest,
}
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1961,7 +1957,7 @@ TEST(EncapsulateSubgraphsTest,
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"h_0_retval", "H:o:0"}});
+ {{"h_0_retval_retval", "H:o:0"}});
{
std::unique_ptr lib_def(
@@ -2034,15 +2030,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
- Node* key_constant =
- KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0"));
+ Node* key_constant = KeyPlaceholder("F1", shape1.opts());
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
{DT_FLOAT}, shape1.opts());
Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
@@ -2055,7 +2051,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
}
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {},
{{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}},
@@ -2076,28 +2072,24 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"D:o:0"},
{{"Tinputs", absl::Span({DT_FLOAT})},
{"Toutputs", absl::Span({})},
- {"ancestors",
- absl::Span({"outside_compilation_O1_host_compute"})},
+ {"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph", ""},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O2"}},
- {"outside_compilation_O1_host_compute"}},
+ {}},
{{"outside_compilation_O3_host_compute"},
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", absl::Span({DT_FLOAT})},
{"Toutputs", absl::Span({})},
- {"ancestors",
- absl::Span({"outside_compilation_O1_host_compute",
- "outside_compilation_O2_host_compute"})},
+ {"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O3"},
{"shape_inference_graph", ""},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O3"}},
- {"outside_compilation_O1_host_compute",
- "outside_compilation_O2_host_compute"}}},
- {{"h_0_retval", "H:o:0"}});
+ {}}},
+ {{"h_0_retval_retval", "H:o:0"}});
{
std::unique_ptr lib_def(
@@ -2169,19 +2161,20 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
},
- {{"f_0_retval", "F:o:0"}});
+ {{"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -2234,19 +2227,20 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
- TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+ std::vector encapsulated_functions{"F1"};
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
- Node* key_constant =
- KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
- Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_1"));
+ Node* key_constant = KeyPlaceholder("F1", shape.opts());
Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
{DT_FLOAT}, shape.opts());
- Node* e = BinaryUnknownShape(known, recv,
+ Node* a = InputShaped(shape.opts().WithName("A"));
+ Node* c = Unary(a, shape.opts().WithName("C"));
+ Node* e = BinaryUnknownShape(c, recv,
shape.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
@@ -2258,7 +2252,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {},
+ "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}},
{{"F"},
@@ -2279,7 +2273,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
{"_outside_compilation_subgraph", "O1"}},
{"c"}},
},
- {{"f_0_retval", "F:o:0"}});
+ {{"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1f4b9c90a4ff0b1166cdb7b5942771b350740ef3
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_util.cc
@@ -0,0 +1,955 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_util.h"
+#include
+#include
+
+#include "absl/strings/str_cat.h"
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/jit/shape_inference.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Returns string attribute value for the node if the attribute is present,
+// otherwise returns empty optional value.
+absl::optional GetStringAttr(const Node& n, const string& attr_name) {
+ auto attr = n.attrs().Find(attr_name);
+ if (!attr) {
+ return absl::nullopt;
+ } else {
+ return attr->s();
+ }
+}
+
+// Adds a value to the node's list attribute.
+template
+Status AppendToListAttr(Node* n, const string& attr_name, const string& value) {
+ std::vector attr_value;
+ Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value);
+ if (!s.ok() && s.code() != error::NOT_FOUND) {
+ return s;
+ }
+
+ n->ClearAttr(attr_name);
+ attr_value.push_back(value);
+ n->AddAttr(attr_name, attr_value);
+ return Status::OK();
+}
+
+// Replaces attribute value.
+template
+void ReplaceAttr(Node* n, const string& attr_name, const T& value) {
+ n->ClearAttr(attr_name);
+ n->AddAttr(attr_name, value);
+}
+
+// Step 1a ~ 1d for PreprocessForEncapsulation(). See comments of
+// PreprocessForEncapsulation() for details.
+Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
+ const string& outside_compilation_attr_name) {
+ // Gather edges to remove. We should not remove the edge while iterating.
+ std::vector edges_to_remove;
+ for (const Edge* e : g->edges()) {
+ if (!e->IsControlEdge()) {
+ continue;
+ }
+
+ auto src_xla_computation =
+ GetStringAttr(*e->src(), xla_computation_attr_name);
+ auto dst_xla_computation =
+ GetStringAttr(*e->dst(), xla_computation_attr_name);
+ auto src_outside_compilation =
+ GetStringAttr(*e->src(), outside_compilation_attr_name);
+ auto dst_outside_compilation =
+ GetStringAttr(*e->dst(), outside_compilation_attr_name);
+
+ if (!src_xla_computation && !dst_xla_computation) {
+ continue;
+ } else if (src_xla_computation && !dst_xla_computation) {
+ if (src_outside_compilation) {
+ // Case 1c: outside compilation to host computation control edge.
+ edges_to_remove.push_back(e);
+
+ TF_RETURN_IF_ERROR(AppendToListAttr(
+ e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
+ }
+ } else if (!src_xla_computation && dst_xla_computation) {
+ if (dst_outside_compilation) {
+ // Case 1c: host computation control to outside compilation edge.
+ edges_to_remove.push_back(e);
+
+ TF_RETURN_IF_ERROR(AppendToListAttr(
+ e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
+ }
+ } else { // src_xla_computation && dst_xla_computation
+ if (*src_xla_computation != *dst_xla_computation) {
+ if (src_outside_compilation && dst_outside_compilation) {
+ // Case 1b: outside compilation to outside compilation control edge.
+ edges_to_remove.push_back(e);
+
+ TF_RETURN_IF_ERROR(AppendToListAttr(
+ e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
+ } else if (src_outside_compilation && !dst_outside_compilation) {
+ // Case 1a: outside compilation to another XLA computaition control
+ // edge.
+ TF_RETURN_IF_ERROR(AppendToListAttr(
+ e->src(), kXlaConnectedToOtherXlaComputationAttrName,
+ *dst_xla_computation));
+ } else if (!src_outside_compilation && dst_outside_compilation) {
+ // Case 1a: another XLA computaition to outside compilation control
+ // edge.
+ TF_RETURN_IF_ERROR(AppendToListAttr(
+ e->dst(), kXlaConnectedFromOtherXlaComputationAttrName,
+ *src_xla_computation));
+ }
+ }
+ }
+ }
+
+ for (auto e : edges_to_remove) {
+ g->RemoveEdge(e);
+ }
+ return Status::OK();
+}
+
+// Step 2 for PreprocessForEncapsulation(). See comments of
+// PreprocessForEncapsulation() for details.
+Status ProcessXlaToXlaDataEdges(Graph* g,
+ const string& xla_computation_attr_name,
+ const string& outside_compilation_attr_name) {
+ // Gather edges between XLA computations. Notice that we do not store `Edge*`
+ // directly because we remove some nodes while adding Identity nodes, and
+ // those Edge pointers might be invalidated.
+ struct EdgeInfo {
+ int dst_input, dst_node_id;
+ };
+ std::vector edges;
+ for (const Edge* e : g->edges()) {
+ if (e->IsControlEdge()) {
+ continue;
+ }
+
+ auto src_xla_computation =
+ GetStringAttr(*e->src(), xla_computation_attr_name);
+ auto dst_xla_computation =
+ GetStringAttr(*e->dst(), xla_computation_attr_name);
+ auto src_outside_compilation =
+ GetStringAttr(*e->src(), outside_compilation_attr_name);
+ auto dst_outside_compilation =
+ GetStringAttr(*e->dst(), outside_compilation_attr_name);
+ if (!src_xla_computation || !dst_xla_computation) {
+ continue;
+ }
+
+ if (*src_xla_computation != *dst_xla_computation) {
+ if (src_outside_compilation || dst_outside_compilation) {
+ edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
+ VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
+ }
+ }
+ }
+
+ // For each XLA -> XLA edge, add an Identity node between src and dst.
+ for (int i = 0; i < edges.size(); i++) {
+ Node* dst = g->FindNodeId(edges[i].dst_node_id);
+ const Edge* e;
+ TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
+ Node* src = e->src();
+ int src_output = e->src_output(), dst_input = e->dst_input();
+ g->RemoveEdge(e);
+
+ // Create Identity node, and connect it between `src` and `dst`.
+ string identity_node_name =
+ absl::StrCat("bridge_", src->name(), "_", dst->name());
+ DataType dtype = src->output_type(src_output);
+ TF_ASSIGN_OR_RETURN(Node * identity_node,
+ BuildIdentityNode(g, identity_node_name, dtype, src,
+ /*requested_device=*/absl::nullopt));
+ identity_node->AddAttr(kBridgeSourceNodeAttrName, src->name());
+ g->AddEdge(src, src_output, identity_node, 0);
+ g->AddEdge(identity_node, 0, dst, dst_input);
+
+ // Replace `e->dst()` because its input node changed.
+ NodeDef new_def = dst->def();
+ *new_def.mutable_input(dst_input) = identity_node->name();
+ TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
+
+ // Other edge in `edges` might have `e->dst()` as src or dst
+ // node. Before removing `e->dst()`, replace those edges with corresponding
+ // edges for `dst_replace_node`.
+ for (int j = i + 1; j < edges.size(); j++) {
+ if (edges[j].dst_node_id == edges[i].dst_node_id) {
+ edges[j].dst_node_id = dst_replace_node->id();
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// Step 3 for PreprocessForEncapsulation(). See comments of
+// PreprocessForEncapsulation() for details.
+Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
+ Graph* g, const string& xla_computation_attr_name,
+ const string& outside_compilation_attr_name) {
+ // Gather edges between outside compilation and host computation. Notice that
+ // we do not store `Edge*` directly because we remove some nodes while adding
+ // Identity nodes, and those Edge pointers might be invalidated.
+ struct EdgeInfo {
+ int dst_input, dst_node_id;
+ bool is_host_to_outside_compilation;
+ };
+ std::vector edges;
+ for (const Edge* e : g->edges()) {
+ if (e->IsControlEdge()) {
+ continue;
+ }
+
+ if (e->src()->attrs().Find(xla_computation_attr_name) == nullptr &&
+ e->dst()->attrs().Find(xla_computation_attr_name) != nullptr &&
+ e->dst()->attrs().Find(outside_compilation_attr_name) != nullptr) {
+ edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(),
+ /*is_host_to_outside_compilation=*/true});
+ VLOG(4) << "Host -> oc edge: " << e->DebugString();
+ } else if (e->dst()->attrs().Find(xla_computation_attr_name) == nullptr &&
+ e->src()->attrs().Find(xla_computation_attr_name) != nullptr &&
+ e->src()->attrs().Find(outside_compilation_attr_name) !=
+ nullptr) {
+ edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(),
+ /*is_host_to_outside_compilation=*/false});
+ VLOG(4) << "Oc -> host edge: " << e->DebugString();
+ }
+ }
+
+ // Remove the edge from host to outside compilation. Add a placeholder as
+ // outside compilation node input.
+ std::map, Node*> placeholders;
+ for (int i = 0; i < edges.size(); i++) {
+ Node* dst = g->FindNodeId(edges[i].dst_node_id);
+ const Edge* e;
+ TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
+ Node* src = e->src();
+ int src_output = e->src_output(), dst_input = e->dst_input();
+ g->RemoveEdge(e);
+
+ // Find or create placeholder node.
+ string new_name =
+ edges[i].is_host_to_outside_compilation
+ ? absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output)
+ : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output);
+ auto placeholder_index = std::make_pair(src->name(), src_output);
+ auto iter = placeholders.find(placeholder_index);
+ Node* placeholder_node;
+ if (iter == placeholders.end()) {
+ NodeDefBuilder placeholder_builder(new_name, "Placeholder");
+ placeholder_builder.Attr("dtype", src->output_type(src_output));
+ if (edges[i].is_host_to_outside_compilation) {
+ placeholder_builder.Attr(kHostToOutsideCompilationOriginalNodeAttrName,
+ src->name());
+ placeholder_builder.Attr(kHostToOutsideCompilationSrcOutputAttrName,
+ src_output);
+ // If this placeholder node is in outside compilation, we need to set
+ // `xla_computation_attr_name` and `outside_compilation_attr_name`.
+ string xla_computation_attr, outside_compilation_attr;
+ TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), xla_computation_attr_name,
+ &xla_computation_attr));
+ TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
+ outside_compilation_attr_name,
+ &outside_compilation_attr));
+ placeholder_builder.Attr(xla_computation_attr_name,
+ xla_computation_attr);
+ placeholder_builder.Attr(outside_compilation_attr_name,
+ outside_compilation_attr);
+ } else {
+ placeholder_builder.Attr(kOutsideCompilationToHostOriginalNodeAttrName,
+ src->name());
+ placeholder_builder.Attr(kOutsideCompilationToHostSrcOutputAttrName,
+ src_output);
+ }
+ NodeDef placeholder_def;
+ TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
+ Status s;
+ placeholder_node = g->AddNode(placeholder_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ placeholders[placeholder_index] = placeholder_node;
+ } else {
+ placeholder_node = iter->second;
+ }
+ g->AddEdge(placeholder_node, 0, dst, dst_input);
+
+ // Replace `e->dst()` because its input node changed.
+ NodeDef new_def = dst->def();
+ *new_def.mutable_input(dst_input) = placeholder_node->name();
+ TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
+
+ // Other edge in `edges` might have `e->dst()` as src or dst
+ // node. Before removing `e->dst()`, replace those edges with corresponding
+ // edges for `dst_replace_node`.
+ for (int j = i + 1; j < edges.size(); j++) {
+ if (edges[j].dst_node_id == edges[i].dst_node_id) {
+ edges[j].dst_node_id = dst_replace_node->id();
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// Step 1 for `PostprocessForEncapsulation`. See comments of
+// `PostprocessForEncapsulation` for details.
+Status RemovePlaceholderBetweenOutsideCompilationAndHostComputation(Graph* g) {
+ // Gather all outside compilation to host computation nodes.
+ struct PlaceHolderNodeInfo {
+ Node* n;
+ bool is_host_to_oc;
+ };
+ std::vector placeholder_nodes;
+ for (Node* n : g->nodes()) {
+ if (n->type_string() == "Placeholder") {
+ if (HasNodeAttr(n->def(),
+ kOutsideCompilationToHostOriginalNodeAttrName)) {
+ placeholder_nodes.push_back({n, false});
+ } else if (HasNodeAttr(n->def(),
+ kHostToOutsideCompilationOriginalNodeAttrName)) {
+ placeholder_nodes.push_back({n, true});
+ }
+ }
+ }
+
+ // Remove the placeholder nodes, and reconnect original edge.
+ auto node_name_index = g->BuildNodeNameIndex();
+ for (auto placeholder_iter : placeholder_nodes) {
+ Node* n = placeholder_iter.n;
+
+ string node_name;
+ int node_src_output;
+ if (placeholder_iter.is_host_to_oc) {
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName,
+ &node_name));
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(),
+ kHostToOutsideCompilationSrcOutputAttrName,
+ &node_src_output));
+ } else {
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(n->attrs(), kOutsideCompilationToHostOriginalNodeAttrName,
+ &node_name));
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(),
+ kOutsideCompilationToHostSrcOutputAttrName,
+ &node_src_output));
+ }
+ auto iter = node_name_index.find(node_name);
+ if (iter == node_name_index.end()) {
+ return errors::Internal(
+ "Cannot find original node for oc -> host placeholder node ",
+ node_name);
+ }
+
+ // Change all usage node to use the original node instead.
+ Node* original_node = iter->second;
+ std::vector control_edges;
+ std::vector data_edges;
+ for (auto e : n->out_edges()) {
+ if (e->IsControlEdge()) {
+ control_edges.push_back(e);
+ } else {
+ data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
+ }
+ }
+ for (const Edge* e : control_edges) {
+ g->AddControlEdge(original_node, e->dst());
+ g->RemoveEdge(e);
+ }
+ for (int i = 0; i < data_edges.size(); i++) {
+ Node* dst = data_edges[i].dst;
+ NodeDef new_def = dst->def();
+ int dst_input = data_edges[i].dst_input;
+ *new_def.mutable_input(dst_input) =
+ absl::StrCat(original_node->name(), ":", node_src_output);
+ TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
+
+ const Edge* edge_to_replace = nullptr;
+ TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
+ g->RemoveEdge(edge_to_replace);
+ g->AddEdge(original_node, node_src_output, replace_node, dst_input);
+
+ // Other edges might have `dst` as dst node. Update those edges with
+ // `replace_node`.
+ for (int j = i + 1; j < data_edges.size(); j++) {
+ if (data_edges[j].dst == dst) {
+ data_edges[j].dst = replace_node;
+ }
+ }
+
+ // Other placeholder node might have `dst` as original node. Update
+ // `node_name_index` with `replace_node`.
+ node_name_index[replace_node->name()] = replace_node;
+ }
+
+ // Remove placeholder node.
+ g->RemoveNode(n);
+ }
+ return Status::OK();
+}
+
+// Step 2 for `PostprocessForEncapsulation`. See comments of
+// `PostprocessForEncapsulation` for details.
+Status RemoveIdentityBetweenDifferentXlaComputation(Graph* g) {
+ // Gather Identity nodes to remove.
+ std::vector bridge_nodes;
+ for (Node* n : g->nodes()) {
+ if (n->type_string() == "Identity" &&
+ HasNodeAttr(n->def(), kBridgeSourceNodeAttrName)) {
+ bridge_nodes.push_back(n);
+ }
+ }
+
+ // Remove the identity nodes, and reconnect the original edge.
+ for (int i = 0; i < bridge_nodes.size(); i++) {
+ Node* n = bridge_nodes[i];
+ const Edge* src_edge = nullptr;
+ TF_RETURN_IF_ERROR(n->input_edge(0, &src_edge));
+
+ // Change all usage node to use the original node instead.
+ std::vector control_edges;
+ std::vector data_edges;
+ for (auto e : n->out_edges()) {
+ if (e->IsControlEdge()) {
+ control_edges.push_back(e);
+ } else {
+ data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
+ }
+ }
+ for (const Edge* e : control_edges) {
+ g->AddControlEdge(src_edge->src(), e->dst());
+ g->RemoveEdge(e);
+ }
+ for (int j = 0; j < data_edges.size(); j++) {
+ Node* dst = data_edges[j].dst;
+ NodeDef new_def = dst->def();
+ int dst_input = data_edges[j].dst_input;
+ *new_def.mutable_input(dst_input) =
+ absl::StrCat(src_edge->src()->name(), ":", src_edge->src_output());
+ TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
+
+ const Edge* edge_to_replace = nullptr;
+ TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
+ g->RemoveEdge(edge_to_replace);
+ g->AddEdge(src_edge->src(), src_edge->src_output(), replace_node,
+ dst_input);
+
+ // Other edges might have `dst` as dst node. Update those edges with
+ // `replace_node`.
+ for (int k = j + 1; k < data_edges.size(); k++) {
+ if (data_edges[k].dst == dst) {
+ data_edges[k].dst = replace_node;
+ }
+ }
+
+ // The node we replaced might be in `bridge_nodes`. If so, update
+ // `bridge_nodes` to use the replaced node.
+ for (int k = i + 1; k < bridge_nodes.size(); k++) {
+ if (bridge_nodes[k] == dst) {
+ bridge_nodes[k] = replace_node;
+ }
+ }
+ }
+
+ // Remove Identity node.
+ g->RemoveNode(n);
+ }
+ return Status::OK();
+}
+
+// Step 3 for `PostprocessForEncapsulation`. See comments of
+// `PostprocessForEncapsulation` for details.
+// We do not need to worry about removed nodes in step 1 and 2;
+// `PreprocessForEncapsulation` will not record control dependencies for those
+// remvoed nodes in the first place.
+Status AddControlDependencies(
+ Graph* g, const std::unordered_map& cluster_node_names) {
+ auto node_name_index = g->BuildNodeNameIndex();
+
+ // Reconnect outside compilation to outside compilation control edge.
+ for (Node* n : g->nodes()) {
+ std::vector control_deps;
+ Status s =
+ GetNodeAttr(n->attrs(), kXlaControlDependenciesAttrName, &control_deps);
+ if (!s.ok()) {
+ if (s.code() != error::NOT_FOUND) {
+ return s;
+ } else {
+ continue;
+ }
+ } else {
+ n->ClearAttr(kXlaControlDependenciesAttrName);
+ for (const string& control_input : control_deps) {
+ auto iter = node_name_index.find(control_input);
+ if (iter == node_name_index.end()) {
+ return errors::Internal("Cannot find original node for ",
+ control_input);
+ }
+ g->AddControlEdge(iter->second, n);
+ }
+ }
+ }
+
+ // Reconnect outside compilation to XLA computation control edge.
+ for (Node* n : g->nodes()) {
+ std::vector control_deps;
+ Status s = GetNodeAttr(
+ n->attrs(), kXlaConnectedToOtherXlaComputationAttrName, &control_deps);
+ if (!s.ok()) {
+ if (s.code() != error::NOT_FOUND) {
+ return s;
+ } else {
+ continue;
+ }
+ } else {
+ n->ClearAttr(kXlaConnectedToOtherXlaComputationAttrName);
+ for (const string& control_input : control_deps) {
+ auto iter = cluster_node_names.find(control_input);
+ if (iter == cluster_node_names.end()) {
+ return errors::Internal("Cannot find cluster node for ",
+ control_input);
+ }
+ auto iter2 = node_name_index.find(iter->second);
+ if (iter2 == node_name_index.end()) {
+ return errors::Internal("Cannot find cluster node for ",
+ iter->second);
+ }
+ g->AddControlEdge(n, iter2->second);
+ }
+ }
+ }
+
+ // Reconnect XLA computation to outside compilation control edge.
+ for (Node* n : g->nodes()) {
+ std::vector control_deps;
+ Status s =
+ GetNodeAttr(n->attrs(), kXlaConnectedFromOtherXlaComputationAttrName,
+ &control_deps);
+ if (!s.ok()) {
+ if (s.code() != error::NOT_FOUND) {
+ return s;
+ } else {
+ continue;
+ }
+ } else {
+ n->ClearAttr(kXlaConnectedFromOtherXlaComputationAttrName);
+ for (const string& control_input : control_deps) {
+ auto iter = cluster_node_names.find(control_input);
+ if (iter == cluster_node_names.end()) {
+ return errors::Internal("Cannot find cluster node for ",
+ control_input);
+ }
+ auto iter2 = node_name_index.find(iter->second);
+ if (iter2 == node_name_index.end()) {
+ return errors::Internal("Cannot find cluster node for ",
+ iter->second);
+ }
+ g->AddControlEdge(iter2->second, n);
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+// Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PreprocessEdgesBetweenOutsideCompilations` for details.
+Status PreprocessControlEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Gather edges to remove. We should not remove the edge while iterating.
+ std::vector edges_to_remove;
+ for (const Edge* e : g->edges()) {
+ if (!e->IsControlEdge()) {
+ continue;
+ }
+
+ auto src_outside_compilation =
+ GetStringAttr(*e->src(), outside_compilation_attr_name);
+ auto dst_outside_compilation =
+ GetStringAttr(*e->dst(), outside_compilation_attr_name);
+
+ if (src_outside_compilation && dst_outside_compilation) {
+ if (*src_outside_compilation != *dst_outside_compilation) {
+ // Case 1a: outside compilation to outside compilation control edge.
+ edges_to_remove.push_back(e);
+
+ TF_RETURN_IF_ERROR(AppendToListAttr(
+ e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName,
+ e->src()->name()));
+ }
+ } else if (src_outside_compilation && !dst_outside_compilation) {
+ // Case 1b: outside compilation to its XLA computation control edge.
+ ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
+ } else if (!src_outside_compilation && dst_outside_compilation) {
+ // Case 1b: XLA computation to outside compilation in it control edge.
+ ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
+ }
+ }
+
+ for (auto e : edges_to_remove) {
+ g->RemoveEdge(e);
+ }
+ return Status::OK();
+}
+
+// Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PreprocessEdgesBetweenOutsideCompilations` for details.
+Status PreprocessDataEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Gather edges between outside compilation and host computation. Notice that
+ // we do not store `Edge*` directly because we remove some nodes while adding
+ // Identity nodes, and those Edge pointers might be invalidated.
+ struct EdgeInfo {
+ int dst_input, dst_node_id;
+ };
+ std::vector edges;
+ for (const Edge* e : g->edges()) {
+ if (e->IsControlEdge()) {
+ continue;
+ }
+
+ auto src_outside_compilation =
+ GetStringAttr(*e->src(), outside_compilation_attr_name);
+ auto dst_outside_compilation =
+ GetStringAttr(*e->dst(), outside_compilation_attr_name);
+
+ if (src_outside_compilation && dst_outside_compilation &&
+ *src_outside_compilation != *dst_outside_compilation) {
+ edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
+ VLOG(4) << "Oc -> oc edge: " << e->DebugString();
+ }
+ }
+
+ // Remove the edge from host to outside compilation. Add a placeholder as
+ // outside compilation node input.
+ std::map, Node*> placeholders;
+ for (int i = 0; i < edges.size(); i++) {
+ Node* dst = g->FindNodeId(edges[i].dst_node_id);
+ const Edge* e;
+ TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
+ Node* src = e->src();
+ int src_output = e->src_output(), dst_input = e->dst_input();
+ g->RemoveEdge(e);
+
+ // Find or create placeholder node.
+ string new_name =
+ absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
+ auto placeholder_index = std::make_pair(src->name(), src_output);
+ auto iter = placeholders.find(placeholder_index);
+ Node* placeholder_node;
+ if (iter == placeholders.end()) {
+ NodeDefBuilder placeholder_builder(new_name, "Placeholder");
+ placeholder_builder.Attr("dtype", src->output_type(src_output));
+ string outside_compilation_attr;
+ TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
+ outside_compilation_attr_name,
+ &outside_compilation_attr));
+ placeholder_builder.Attr(outside_compilation_attr_name,
+ outside_compilation_attr);
+ placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName,
+ src->name());
+ placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName,
+ src_output);
+ NodeDef placeholder_def;
+ TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
+ Status s;
+ placeholder_node = g->AddNode(placeholder_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ placeholders[placeholder_index] = placeholder_node;
+ } else {
+ placeholder_node = iter->second;
+ }
+ g->AddEdge(placeholder_node, 0, dst, dst_input);
+
+ // Replace `e->dst()` because its input node changed.
+ NodeDef new_def = dst->def();
+ *new_def.mutable_input(dst_input) = placeholder_node->name();
+ TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
+
+ // Other edge in `edges` might have `e->dst()` as src or dst
+ // node. Before removing `e->dst()`, replace those edges with
+ // corresponding edges for `dst_replace_node`.
+ for (int j = i + 1; j < edges.size(); j++) {
+ if (edges[j].dst_node_id == edges[i].dst_node_id) {
+ edges[j].dst_node_id = dst_replace_node->id();
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PostprocessEdgesBetweenOutsideCompilations` for details.
+Status PostprocessDataEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Gather all outside compilation to outside compilation nodes.
+ std::vector placeholder_nodes;
+ for (Node* n : g->nodes()) {
+ if (n->type_string() == "Placeholder" &&
+ HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) {
+ placeholder_nodes.push_back(n);
+ }
+ }
+
+ // Remove the placeholder nodes, and reconnect original edge.
+ auto node_name_index = g->BuildNodeNameIndex();
+ for (auto n : placeholder_nodes) {
+ string node_name;
+ int node_src_output;
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name));
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output));
+ auto iter = node_name_index.find(node_name);
+ if (iter == node_name_index.end()) {
+ return errors::Internal(
+ "Cannot find original node for oc -> host placeholder node ",
+ node_name);
+ }
+
+ // Change all usage node to use the original node instead.
+ Node* original_node = iter->second;
+ std::vector control_edges;
+ std::vector data_edges;
+ for (auto e : n->out_edges()) {
+ if (e->IsControlEdge()) {
+ control_edges.push_back(e);
+ } else {
+ data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
+ }
+ }
+ for (const Edge* e : control_edges) {
+ g->AddControlEdge(original_node, e->dst());
+ g->RemoveEdge(e);
+ }
+ for (int i = 0; i < data_edges.size(); i++) {
+ Node* dst = data_edges[i].dst;
+ NodeDef new_def = dst->def();
+ int dst_input = data_edges[i].dst_input;
+ *new_def.mutable_input(dst_input) =
+ absl::StrCat(original_node->name(), ":", node_src_output);
+ TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
+
+ const Edge* edge_to_replace = nullptr;
+ TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
+ g->RemoveEdge(edge_to_replace);
+ g->AddEdge(original_node, node_src_output, replace_node, dst_input);
+
+ // Other edges might have `dst` as dst node. Update those edges with
+ // `replace_node`.
+ for (int j = i + 1; j < data_edges.size(); j++) {
+ if (data_edges[j].dst == dst) {
+ data_edges[j].dst = replace_node;
+ }
+ }
+
+ // Other placeholder node might have `dst` as original node. Update
+ // `node_name_index` with `replace_node`.
+ node_name_index[replace_node->name()] = replace_node;
+ }
+
+ // Remove placeholder node.
+ g->RemoveNode(n);
+ }
+ return Status::OK();
+}
+
+// Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PostprocessEdgesBetweenOutsideCompilations` for details.
+Status PostprocessControlEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ auto node_name_index = g->BuildNodeNameIndex();
+
+ // Reconnect outside compilation to outside compilation control edge.
+ for (Node* n : g->nodes()) {
+ std::vector control_deps;
+ Status s =
+ GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName,
+ &control_deps);
+ if (!s.ok()) {
+ if (s.code() != error::NOT_FOUND) {
+ return s;
+ } else {
+ continue;
+ }
+ } else {
+ n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName);
+ for (const string& control_input : control_deps) {
+ auto iter = node_name_index.find(control_input);
+ if (iter == node_name_index.end()) {
+ return errors::Internal("Cannot find original node for ",
+ control_input);
+ }
+ g->AddControlEdge(iter->second, n);
+ }
+ }
+ }
+ return Status::OK();
+}
+} // namespace
+
+const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";
+
+const char kXlaConnectedToOtherXlaComputationAttrName[] =
+ "_xla_connected_to_other_xla_computation";
+const char kXlaConnectedFromOtherXlaComputationAttrName[] =
+ "_xla_connected_from_other_xla_computation";
+const char kXlaControlDependenciesAttrName[] = "_xla_control_dependencies";
+const char kBridgeSourceNodeAttrName[] = "_xla_bridge_src";
+const char kOutsideCompilationToHostOriginalNodeAttrName[] =
+ "_xla_oc_to_host_node_name";
+const char kOutsideCompilationToHostSrcOutputAttrName[] =
+ "_xla_oc_to_host_src_output";
+const char kHostToOutsideCompilationOriginalNodeAttrName[] =
+ "_xla_host_to_oc_node_name";
+const char kHostToOutsideCompilationSrcOutputAttrName[] =
+ "_xla_host_to_oc_src_output";
+const char kXlaConnectedToXlaComputationAttrName[] =
+ "_xla_connected_to_xla_computation";
+const char kXlaConnectedFromXlaComputationAttrName[] =
+ "_xla_connected_from_xla_computation";
+const char kOutsideCompilationOriginalNodeAttrName[] =
+ "_xla_oc_to_oc_node_name";
+const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output";
+const char kXlaControlDependenciesWithinXlaClusterAttrName[] =
+ "_xla_control_dependencies_within_xla_cluster";
+
+Status PerformStaticShapeInferenceBeforeEncapsulation(
+ Graph* g, const string& xla_computation_attr_name,
+ const string& outside_compilation_attr_name) {
+ // Find all outside compilation to XLA computation data edges.
+ std::unordered_set outside_compilation_send_nodes;
+ for (auto e : g->edges()) {
+ if (e->IsControlEdge()) {
+ continue;
+ }
+
+ auto src_computation = GetStringAttr(*e->src(), xla_computation_attr_name);
+ auto dst_computation = GetStringAttr(*e->dst(), xla_computation_attr_name);
+ if (!src_computation || !dst_computation ||
+ *src_computation != *dst_computation) {
+ continue;
+ }
+
+ auto src_outside_compilation =
+ GetStringAttr(*e->src(), outside_compilation_attr_name);
+ auto dst_outside_compilation =
+ GetStringAttr(*e->dst(), outside_compilation_attr_name);
+ if (src_outside_compilation && !dst_outside_compilation) {
+ outside_compilation_send_nodes.insert(e->src());
+ }
+ }
+
+ // Perform shape inference.
+ std::map arg_shapes;
+ GraphShapeInfo shape_info;
+ TF_RETURN_IF_ERROR(
+ InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info));
+
+ // Add attribute for output shapes.
+ for (Node* n : outside_compilation_send_nodes) {
+ auto iter = shape_info.find(n->name());
+ if (iter == shape_info.end()) {
+ continue;
+ }
+
+ std::vector output_shapes;
+ std::transform(iter->second.begin(), iter->second.end(),
+ std::back_inserter(output_shapes),
+ [](const InferredShape& inferred_shape) {
+ return inferred_shape.shape;
+ });
+ n->AddAttr(kXlaInferredShapesAttrName, output_shapes);
+ }
+
+ return Status::OK();
+}
+
+Status PreprocessForEncapsulation(Graph* g,
+ const string& xla_computation_attr_name,
+ const string& outside_compilation_attr_name) {
+ TF_RETURN_IF_ERROR(ProcessControlEdges(g, xla_computation_attr_name,
+ outside_compilation_attr_name));
+ TF_RETURN_IF_ERROR(ProcessXlaToXlaDataEdges(g, xla_computation_attr_name,
+ outside_compilation_attr_name));
+ TF_RETURN_IF_ERROR(ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
+ g, xla_computation_attr_name, outside_compilation_attr_name));
+ return Status::OK();
+}
+
+Status PostprocessForEncapsulation(
+ Graph* g, const string& xla_computation_attr_name,
+ const string& outside_compilation_attr_name,
+ const std::unordered_map& clusters) {
+ // The `node` pointer in `XlaClusterInfo` might be invalidated in step 1/2,
+ // but the node name won't change. Record cluster node name for
+ // `AddControlDependencies`.
+ std::unordered_map cluster_node_names;
+ for (const auto& iter : clusters) {
+ cluster_node_names[iter.first] = iter.second.node->name();
+ }
+
+ TF_RETURN_IF_ERROR(
+ RemovePlaceholderBetweenOutsideCompilationAndHostComputation(g));
+ TF_RETURN_IF_ERROR(RemoveIdentityBetweenDifferentXlaComputation(g));
+ TF_RETURN_IF_ERROR(AddControlDependencies(g, cluster_node_names));
+ return Status::OK();
+}
+
+Status PreprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Remove edges from source node to outside compilation nodes, and edges
+ // from outside compilation nodes to sink node.
+ std::vector edges_to_remove;
+ for (const Edge* e : g->source_node()->out_edges()) {
+ if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
+ edges_to_remove.push_back(e);
+ }
+ }
+ for (const Edge* e : g->sink_node()->in_edges()) {
+ if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) {
+ edges_to_remove.push_back(e);
+ }
+ }
+ for (auto e : edges_to_remove) {
+ g->RemoveEdge(e);
+ }
+
+ TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ return Status::OK();
+}
+
+Status PostprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..e363bc5754ac395bae262dc67a780a0173efaf5e
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_util.h
@@ -0,0 +1,210 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file contains some utility functions for encapsulating XLA computation
+// in host graph and encapsulating outside compilation in XLA computation.
+
+#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
+#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Attribute marking output tensor shapes inferred by XLA. Attribute value is
+// a list of PartialTensorShape objects.
+extern const char kXlaInferredShapesAttrName[];
+
+// Infer output shapes for outside compilation nodes which have output data
+// edges to XLA computation nodes. These shapes will be used later by XLA
+// compiler as output shapes of the outside compilation's XlaHostCompute op.
+// XLA computation nodes will be mark by attr `xla_computation_attr_name`;
+// outside compilation nodes will be marked by both attr
+// `xla_computation_attr_name` and `outside_compilation_attr_name`.
+//
+// Those outside compilation nodes will be marked with attribute
+// `kXlaInferredShapesAttrName`.
+//
+// We have to perform shape inference before encapsulation because after
+// encapsulation, some nodes will be encapsulated into function call, and shape
+// inference does not handle function call at the moment.
+Status PerformStaticShapeInferenceBeforeEncapsulation(
+ Graph* g, const string& xla_computation_attr_name,
+ const string& outside_compilation_attr_name);
+
+// Attribute indicating that some ops in other XLA computation has control
+// dependency on this node. Attribute value will be a list of string (XLA
+// computation names).
+extern const char kXlaConnectedToOtherXlaComputationAttrName[];
+
+// Attribute indicating that this node has control dependency on some ops in
+// other XLA computation. Attribute value will be a list of string (XLA
+// computation names).
+extern const char kXlaConnectedFromOtherXlaComputationAttrName[];
+
+// Attribute indicating that this node has control dependencies on some other
+// nodes. Attribute value will be a list of string (node names).
+extern const char kXlaControlDependenciesAttrName[];
+
+// Attribute indicating that this is an Identity node added to act as a bridge
+// between different XLA computations. Attribute value will be string (source
+// node name).
+extern const char kBridgeSourceNodeAttrName[];
+
+// Attribute indicating that this is an Placeholder node added to act as a
+// temporary input node for an outside compilation node. Attribute value will be
+// string (original input node name).
+extern const char kOutsideCompilationToHostOriginalNodeAttrName[];
+
+// Attribute indicating that this is an Placeholder node added to act as a
+// temporary input node for an outside compilation node. Attribute value will be
+// int (src_output for original edge).
+extern const char kOutsideCompilationToHostSrcOutputAttrName[];
+
+// Attribute indicating that some ops in this node's XLA computation has control
+// dependency on this node. Attribute value will always be "true".
+extern const char kXlaConnectedToXlaComputationAttrName[];
+
+// Attribute indicating that this node has control dependency on some ops in
+// this node's XLA computation. Attribute value will always be "true".
+extern const char kXlaConnectedFromXlaComputationAttrName[];
+
+// Attribute indicating that this is an Placeholder node added to act as a
+// temporary input node for an host node. Attribute value will be string
+// (original input node name).
+extern const char kHostToOutsideCompilationOriginalNodeAttrName[];
+
+// Attribute indicating that this is an Placeholder node added to act as a
+// temporary input node for a host node. Attribute value will be int (src_output
+// for original edge).
+extern const char kHostToOutsideCompilationSrcOutputAttrName[];
+
+// Attribute indicating that this is an Placeholder node added to act as a
+// temporary input node for an outside compilation node. Attribute value will be
+// string (original input node name).
+extern const char kOutsideCompilationOriginalNodeAttrName[];
+
+// Attribute indicating that this is an Placeholder node added to act as a
+// temporary input node for an outside compilation node. Attribute value will be
+// int (src_output for original edge).
+extern const char kOutsideCompilationSrcOutputAttrName[];
+
+// Attribute indicating that this node has control dependencies on some other
+// nodes within the same XLA cluster. Attribute value will be a list of string
+// (node names).
+extern const char kXlaControlDependenciesWithinXlaClusterAttrName[];
+
+// Preprocesses edges between different XLA clusters for encapsulation. It will
+// perform the following operations in order:
+//
+// 1a. For control edges between outside compilation and another XLA
+// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName
+// = XLA computation node name" to the outside compilation node.
+// 1b. For control edges between different outside compilations (in different
+// XLA computations), remove the edge and add attr
+// "kXlaControlDependenciesAttrName = src node name" to dst node.
+// 1c. For control edges between outside compilation and host computation,
+// remove the edge and add attr "kXlaControlDependenciesAttrName = src node
+// name" to dst node.
+// 2. For data edges between different XLA computations, if either src or dst
+// is outside compilation, add an Identity node in between the edge. The
+// identity node will have attr kBridgeSourceNodeAttrName.
+// 3. For data edges between outside compilation and host computation, remove
+// the edge and create a Placeholder node as dst node's input.
+Status PreprocessForEncapsulation(Graph* g,
+ const string& xla_computation_attr_name,
+ const string& outside_compilation_attr_name);
+
+// Information for XLA computation.
+struct XlaClusterInfo {
+ // Add an explicitly-defined default constructor for this class.
+ //
+ // The compiler may delete the default constructor here because
+ // host_compute_core is a const member whose type (std::map) doesn't
+ // necessarily have a user provided constructor -- while libc++ and
+ // libstdc++ 4.8 provide a user defined default constructor, libstdc++ at
+ // least >= 7.3 does not. See also c++11 [class.ctor] p5.
+ //
+ // TODO(klimek): In c++17 we'll be able to initialize host_compute_core
+ // without losing aggregate initialization, which allows us to get rid of
+ // the constructor definitions again.
+ XlaClusterInfo() {}
+ XlaClusterInfo(const string& cluster_name,
+ const NameAttrList& func_name_attrs, Node* node,
+ const std::map& host_compute_core)
+ : cluster_name(cluster_name),
+ func_name_attrs(func_name_attrs),
+ node(node),
+ host_compute_core(host_compute_core) {}
+ // XLA cluster name. It might be different from `func_name`.
+ const string cluster_name;
+ // Name and attributes of XLA computation function.
+ const NameAttrList func_name_attrs;
+ // The XLA computation node in the graph.
+ Node* node;
+ // A mapping from outside compilation cluster name to its device assignment.
+ const std::map host_compute_core;
+};
+
+// Postprocesses edges between different XLA clusters for encapsulation. This
+// function reverts what `PreprocessForEncapsulation` did. It will perform the
+// following operations in order:
+//
+// 1. Remove Placeholder nodes between outside compilation and host computation
+// (created in `PreprocessForEncapsulation` step 3).
+// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2.
+// 3a. Reconnect control edges between outside compilation and another XLA
+// computation (marked by `PreprocessForEncapsulation` step 1a).
+// 3b. Reconnect control edges between different outside compilations (marked by
+// `PreprocessForEncapsulation` step 1b).
+// 3c. Reconnect control edges between outside compilation and host computation
+// (marked by `PreprocessForEncapsulation` step 1c).
+Status PostprocessForEncapsulation(
+ Graph* g, const string& xla_computation_attr_name,
+ const string& outside_compilation_attr_name,
+ const std::unordered_map& clusters);
+
+// Preprocesses edges within the same XLA cluster. It will perform the following
+// operations in order:
+//
+// 0. Remove edges from source node to outside compilation nodes, and edges
+// from outside compilation nodes to sink node.
+// 1a. For edges between different outside compilation clusters, remove the edge
+// and add attr "kXlaControlDependenciesWithinXlaClusterAttrName = src node
+// name" to dst node.
+// 1b. For control edges between outside compilation and its XLA computation,
+// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
+// outside compilation node.
+// 2. For data edges between different outside compilations, remove the edge
+// and create a Placeholder node as dst node's input.
+Status PreprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name);
+
+// Postprocesses edges within the same XLA cluster. This function reverts what
+// `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the
+// following operations in order:
+//
+// 1. Remove Placeholder nodes between different outside compilations (created
+// in `PreprocessEdgesBetweenOutsideCompilations` step 2).
+// 2a. Reconnect control edges between different outside compilations (marked by
+// `PreprocessEdgesBetweenOutsideCompilations` step 1a).
+// Notice that control edges marked by
+// `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here.
+// They are handled in `RewriteOutsideCompilationSubgraphFn`.
+Status PostprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name);
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3b8b49cb92f3e453883a8e64e12ce3748a5173f6
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_util_test.cc
@@ -0,0 +1,394 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_util.h"
+
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
+ // Build the graph:
+ // "add" = "const_0" + "const_1"
+ // "identity" = "add"
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {2});
+ Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {2});
+ Output add = ops::Add(s.WithOpName("add"), const_0, const_1);
+ Output identity = ops::Identity(s.WithOpName("identity"), add);
+ Graph g(OpRegistry::Global());
+ TF_CHECK_OK(s.ToGraph(&g));
+
+ // "add" node is outside compilation node, "identity" node is XLA node.
+ auto node_index = g.BuildNodeNameIndex();
+ Node *add_node = node_index["add"], *identity_node = node_index["identity"];
+ add_node->AddAttr("_xla", "cluster");
+ add_node->AddAttr("_oc", "cluster");
+ identity_node->AddAttr("_xla", "cluster");
+ TF_CHECK_OK(
+ PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc"));
+
+ // Check that only "add" node now has _xla_inferred_shapes attr.
+ std::vector nodes_with_inferred_shape;
+ for (Node *n : g.nodes()) {
+ if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) {
+ nodes_with_inferred_shape.push_back(n);
+ }
+ }
+ EXPECT_EQ(nodes_with_inferred_shape.size(), 1);
+ EXPECT_EQ(nodes_with_inferred_shape[0], add_node);
+ std::vector output_shapes;
+ TF_CHECK_OK(GetNodeAttr(add_node->attrs(), kXlaInferredShapesAttrName,
+ &output_shapes));
+ EXPECT_EQ(output_shapes.size(), 1);
+ TensorShapeProto shape_proto;
+ output_shapes[0].AsProto(&shape_proto);
+ EXPECT_EQ(shape_proto.dim_size(), 1);
+ EXPECT_EQ(shape_proto.dim(0).size(), 2);
+}
+
+TEST(PreprocessForEncapsulationTest, ControlEdges) {
+ // Build the graph:
+ // "const_0" and "const_1" in host computation
+ // "add" = "const_0" + "const_1" in XLA computation 0
+ // "identity0" = "add" in XLA computation 0 & outside compilation 0
+ // "identity1" = "identity0" in XLA computation 0
+ // "identity2" = "identity1" in host computation
+ // "identity3" = "identity2" in XLA computation 1
+ // "identity4" = "identity3" in XLA computation 1 & outside compilation 1
+ // "identity5" = "identity4" in XLA computation 1
+ // "identity6" = "identity5" in host computation
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
+ Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
+ Output add = ops::Add(s.WithOpName("add"), const_0, const_1);
+ Output identity0 = ops::Identity(s.WithOpName("identity0"), add);
+ Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
+ Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
+ Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
+ Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3);
+ Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4);
+ Graph g(OpRegistry::Global());
+ TF_CHECK_OK(s.ToGraph(&g));
+ auto node_index = g.BuildNodeNameIndex();
+
+ // Set XLA computation/outside compilation attr, and add control edges.
+ Node *const0_node = node_index["const_0"], *add_node = node_index["add"],
+ *identity0_node = node_index["identity0"],
+ *identity1_node = node_index["identity1"],
+ *identity2_node = node_index["identity2"],
+ *identity3_node = node_index["identity3"],
+ *identity4_node = node_index["identity4"],
+ *identity5_node = node_index["identity5"];
+ add_node->AddAttr("_xla", "0");
+ identity0_node->AddAttr("_xla", "0");
+ identity0_node->AddAttr("_oc", "0");
+ identity1_node->AddAttr("_xla", "0");
+ identity3_node->AddAttr("_xla", "1");
+ identity4_node->AddAttr("_xla", "1");
+ identity4_node->AddAttr("_oc", "0");
+ identity5_node->AddAttr("_xla", "1");
+ // Case 1a: control edges between outside compilation and another XLA
+ // computation.
+ g.AddControlEdge(identity0_node, identity3_node);
+ g.AddControlEdge(identity1_node, identity4_node);
+ // Case 1b: control edges between different outside compilations.
+ g.AddControlEdge(identity0_node, identity4_node);
+ // Case 1c: control edges between outside compilation and host computation.
+ g.AddControlEdge(const0_node, identity0_node);
+ g.AddControlEdge(identity0_node, identity2_node);
+
+ TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
+
+ // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name"
+ // to the outside compilation node.
+ std::vector attr;
+ TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
+ kXlaConnectedToOtherXlaComputationAttrName, &attr));
+ EXPECT_EQ(attr.size(), 1);
+ EXPECT_EQ(attr[0], "1");
+ attr.clear();
+ TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
+ kXlaConnectedFromOtherXlaComputationAttrName, &attr));
+ EXPECT_EQ(attr.size(), 1);
+ EXPECT_EQ(attr[0], "0");
+ // Case 1b: add attr "_xla_control_deps = src node name" to dst node.
+ attr.clear();
+ TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
+ kXlaControlDependenciesAttrName, &attr));
+ EXPECT_EQ(attr.size(), 1);
+ EXPECT_EQ(attr[0], "identity0");
+ // Case 1c: add attr "_xla_control_deps = src node name" to dst node.
+ attr.clear();
+ TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
+ kXlaControlDependenciesAttrName, &attr));
+ EXPECT_EQ(attr.size(), 1);
+ EXPECT_EQ(attr[0], "const_0");
+ attr.clear();
+ TF_CHECK_OK(GetNodeAttr(identity2_node->def(),
+ kXlaControlDependenciesAttrName, &attr));
+ EXPECT_EQ(attr.size(), 1);
+ EXPECT_EQ(attr[0], "identity0");
+}
+
+TEST(PreprocessForEncapsulationTest, DataEdges) {
+ // Build the graph:
+ // "const_0" and "const_1" in host computation
+ // "identityn0" = ("const_0", "const_1") in host computation 0
+ // "add0" = "const_0" + "const_1" in XLA computation 0
+ // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
+ // "identity0" = "add1" in XLA computation 0
+ // "add2" = "add1" + "identity0" in host computation
+ // "add3" = "add1" + "add2" in XLA computation 1
+ // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0
+ // "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 &
+ // outside compilation 0
+ // "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 &
+ // outside compilation 0
+ // "identity1" = "add4" in XLA computation 1
+ // "identity2" = "identity1" in host computation
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
+ Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
+ auto identityn0 =
+ ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1});
+ Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
+ Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
+ Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
+ Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
+ Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
+ Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
+ Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]);
+ auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"),
+ {identityn0[0], identityn0[1]});
+ Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
+ Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
+ Graph g(OpRegistry::Global());
+ TF_CHECK_OK(s.ToGraph(&g));
+ auto node_index = g.BuildNodeNameIndex();
+
+ // Set XLA computation/outside compilation attr.
+ Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
+ *identity0_node = node_index["identity0"],
+ *add3_node = node_index["add3"], *add4_node = node_index["add4"],
+ *add5_node = node_index["add5"],
+ *identityn1_node = node_index["identityn_1"],
+ *identity1_node = node_index["identity1"];
+ add0_node->AddAttr("_xla", "0");
+ add1_node->AddAttr("_xla", "0");
+ add1_node->AddAttr("_oc", "0");
+ identity0_node->AddAttr("_xla", "0");
+ add3_node->AddAttr("_xla", "1");
+ add4_node->AddAttr("_xla", "1");
+ add4_node->AddAttr("_oc", "0");
+ add5_node->AddAttr("_xla", "1");
+ add5_node->AddAttr("_oc", "0");
+ identityn1_node->AddAttr("_xla", "1");
+ identityn1_node->AddAttr("_oc", "0");
+ identity1_node->AddAttr("_xla", "1");
+
+ TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
+
+ // Check input nodes for related data edges.
+ node_index = g.BuildNodeNameIndex();
+ // Step 2: add an Identity node between different XLA computations.
+ Node *bridge_add1_add3 = node_index["bridge_add1_add3"];
+ EXPECT_NE(bridge_add1_add3, nullptr);
+ string str;
+ TF_CHECK_OK(
+ GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str));
+ EXPECT_EQ(str, "add1");
+ Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"];
+ EXPECT_NE(bridge_identity0_add4, nullptr);
+ // Step 3: add placeholder for edges between host computation and outside
+ // compilation.
+ EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0");
+ Node *add1_oc_to_host_placeholder =
+ node_index["add1_oc_to_host_placeholder_0"];
+ TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
+ kOutsideCompilationToHostOriginalNodeAttrName, &str));
+ EXPECT_EQ(str, "add1");
+ int i;
+ TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
+ kOutsideCompilationToHostSrcOutputAttrName, &i));
+ EXPECT_EQ(i, 0);
+ add4_node = node_index["add4"];
+ ASSERT_NE(add4_node, nullptr);
+ EXPECT_EQ(add4_node->def().input(0),
+ "bridge_identity0_add4_host_to_oc_placeholder_0");
+ Node *identity0_host_to_oc_placeholder =
+ node_index["bridge_identity0_add4_host_to_oc_placeholder_0"];
+ TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
+ kHostToOutsideCompilationOriginalNodeAttrName, &str));
+ EXPECT_EQ(str, "bridge_identity0_add4");
+ TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
+ kHostToOutsideCompilationSrcOutputAttrName, &i));
+ EXPECT_EQ(i, 0);
+
+ // Check different placeholder nodes are created for different src_output.
+ Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"],
+ *placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"];
+ EXPECT_NE(placeholder0, nullptr);
+ EXPECT_NE(placeholder1, nullptr);
+ // Check we only have 2 placeholder nodes created for "identityn_0".
+ int placeholder_count = 0;
+ for (Node *n : g.nodes()) {
+ if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) {
+ string attr;
+ TF_CHECK_OK(GetNodeAttr(
+ n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr));
+ if (attr == "identityn_0") {
+ ++placeholder_count;
+ }
+ }
+ }
+ EXPECT_EQ(placeholder_count, 2);
+}
+
+TEST(PostprocessForEncapsulationTest, ControlEdges) {
+ // Build the graph:
+ // "const0"
+ // "identity0" = "const0" (XLA computation 0)
+ // "identity1" = "identity0"
+ // "identity2" = "identity1" (XLA computation 1)
+ // "identity3" = "identity2"
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output const0 = ops::Const(s.WithOpName("const0"), 1, {});
+ Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
+ Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
+ Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
+ Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
+ Graph g(OpRegistry::Global());
+ TF_CHECK_OK(s.ToGraph(&g));
+ auto node_index = g.BuildNodeNameIndex();
+
+ // Set XLA computation/outside compilation attr, and add control edges.
+ Node *const0_node = node_index["const0"],
+ *identity0_node = node_index["identity0"],
+ *identity1_node = node_index["identity1"],
+ *identity2_node = node_index["identity2"],
+ *identity3_node = node_index["identity3"];
+ identity1_node->AddAttr(kXlaConnectedFromOtherXlaComputationAttrName,
+ std::vector{"0"});
+ identity1_node->AddAttr(kXlaConnectedToOtherXlaComputationAttrName,
+ std::vector{"1"});
+ identity3_node->AddAttr(kXlaControlDependenciesAttrName,
+ std::vector{"const0", "identity1"});
+
+ std::unordered_map clusters;
+ clusters["0"].node = identity0_node;
+ clusters["1"].node = identity2_node;
+ TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters));
+
+ // Case 3a: we have control edge identity0 -> identity1, and identity1 ->
+ // identity2.
+ bool edge_identity0_identity1 = false, edge_identity1_identity2 = false;
+ for (const Edge *e : g.edges()) {
+ if (!e->IsControlEdge()) {
+ continue;
+ }
+ if (e->src() == identity0_node && e->dst() == identity1_node) {
+ edge_identity0_identity1 = true;
+ } else if (e->src() == identity1_node && e->dst() == identity2_node) {
+ edge_identity1_identity2 = true;
+ }
+ }
+ EXPECT_TRUE(edge_identity0_identity1);
+ EXPECT_TRUE(edge_identity1_identity2);
+ // Case 3b: we have control edge const0 -> identity3, and identity1 ->
+ // identity3.
+ bool edge_const0_identity3 = false, edge_identity1_identity3 = false;
+ for (const Edge *e : g.edges()) {
+ if (!e->IsControlEdge()) {
+ continue;
+ }
+ if (e->src() == const0_node && e->dst() == identity3_node) {
+ edge_const0_identity3 = true;
+ } else if (e->src() == identity1_node && e->dst() == identity3_node) {
+ edge_identity1_identity3 = true;
+ }
+ }
+ EXPECT_TRUE(edge_const0_identity3);
+ EXPECT_TRUE(edge_identity1_identity3);
+}
+
+TEST(PostprocessForEncapsulationTest, DataEdges) {
+ // Build the graph:
+ // "const0" in outside compilation "0"
+ // "placeholder0" (for "const0") in host computation
+ // "add0" = "placeholder0" + "placeholder0" in host computation
+ // "placeholder1" (for "add0") in outside compilation 1
+ // "add1" = "placeholder1" + "placeholder1" in outside compilation 1
+ //
+ // "bridge" = "placeholder0" in host computation
+ // "placeholder2" (for "bridge") in outside compilation 1
+ // "add2" = "placeholder2" + "placeholder2" in outside compilation 1
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output const0 = ops::Const(s.WithOpName("const0"), 1, {});
+ Output placeholder0 =
+ ops::Placeholder(s.WithOpName("placeholder0"), DT_INT32);
+ Output add0 = ops::Add(s.WithOpName("add0"), placeholder0, placeholder0);
+ Output placeholder1 =
+ ops::Placeholder(s.WithOpName("placeholder1"), DT_INT32);
+ Output add1 = ops::Add(s.WithOpName("add1"), placeholder1, placeholder1);
+ Output bridge = ops::Identity(s.WithOpName("bridge"), placeholder0);
+ Output placeholder2 =
+ ops::Placeholder(s.WithOpName("placeholder2"), DT_INT32);
+ Output add2 = ops::Add(s.WithOpName("add2"), placeholder2, placeholder2);
+ Graph g(OpRegistry::Global());
+ TF_CHECK_OK(s.ToGraph(&g));
+ auto node_index = g.BuildNodeNameIndex();
+
+ // Set related attributes.
+ Node *placeholder0_node = node_index["placeholder0"];
+ placeholder0_node->AddAttr(kOutsideCompilationToHostOriginalNodeAttrName,
+ "const0");
+ placeholder0_node->AddAttr(kOutsideCompilationToHostSrcOutputAttrName, 0);
+ Node *placeholder1_node = node_index["placeholder1"];
+ placeholder1_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName,
+ "add0");
+ placeholder1_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0);
+ Node *bridge_node = node_index["bridge"];
+ bridge_node->AddAttr(kBridgeSourceNodeAttrName, "const0");
+ Node *placeholder2_node = node_index["placeholder2"];
+ placeholder2_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName,
+ "bridge");
+ placeholder2_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0);
+
+ std::unordered_map clusters;
+ TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters));
+
+ // Result graph should be:
+ // "add0" = "const0" + "const0"
+ // "add1" = "add0" + "add0"
+ // "add2" = "const0" + "const0"
+ node_index = g.BuildNodeNameIndex();
+ EXPECT_EQ(node_index.size(), 6);
+ EXPECT_EQ(node_index["add0"]->def().input(0), "const0:0");
+ EXPECT_EQ(node_index["add0"]->def().input(1), "const0:0");
+ EXPECT_EQ(node_index["add1"]->def().input(0), "add0:0");
+ EXPECT_EQ(node_index["add1"]->def().input(1), "add0:0");
+ EXPECT_EQ(node_index["add2"]->def().input(0), "const0:0");
+ EXPECT_EQ(node_index["add2"]->def().input(1), "const0:0");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index 2ce6fa73fc448ca83fa392aa909cb385453eb8b6..d334100aa4a915a87fb05d371e0e3379a7ee05f2 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -195,8 +195,11 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors,
e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
e->dst()->type_string() != kXlaClusterOutput) {
return errors::InvalidArgument(
- "Undeclared output of XLA computation. A common cause of this error "
- "is variable initializers that depend on the XLA computation. Edge: ",
+ "Undeclared output of XLA computation. Some common causes of this "
+ "error are: 1) variable initializers that depend on the XLA "
+ "computation; 2) gradient computations that depend on the XLA "
+ "computation, which can be mitigated by moving gradient computations "
+ "inside XLA computation. Offending edge: ",
e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
e->dst_input());
}
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e3c7e2f89be9b37b51a633dabb099969c181013f
--- /dev/null
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
@@ -0,0 +1,941 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
+
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/jit/encapsulate_util.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Add a key placeholder node to the graph. The key placeholder node will be
+// used as input for XlaRecvAtHost/XlaSendFromHost nodes.
+xla::StatusOr AddHostComputeKeyPlaceholder(
+ const string& xla_cluster_name, Graph* g) {
+ NodeDef key_def;
+ NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"),
+ "Placeholder");
+ builder.Attr("dtype", DT_STRING);
+ builder.Attr("shape", PartialTensorShape({2}));
+ builder.Attr("_host_compute_call_node", xla_cluster_name);
+ Status s = builder.Finalize(&key_def);
+ if (!s.ok()) return s;
+
+ Node* n = g->AddNode(key_def, &s);
+ if (!s.ok()) return s;
+ return n;
+}
+
+// Returns if the node is a XLA computation key placeholder.
+bool IsKeyPlaceholderNode(const Node& n) {
+ return n.type_string() == "Placeholder" &&
+ absl::EndsWith(n.name(), "_key_placeholder");
+}
+
+// Returns nodes with given type.
+std::vector GatherNodesWithType(const Graph& g, const string& type) {
+ std::vector result;
+ for (Node* n : g.nodes()) {
+ if (n->type_string() == type) {
+ result.push_back(n);
+ }
+ }
+ return result;
+}
+
+// Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`.
+Status GetArgDataTypes(const std::vector& arg_nodes,
+ std::vector* recv_at_host_dtypes) {
+ recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID);
+ for (auto* n : arg_nodes) {
+ int index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+ DataType dtype;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
+ (*recv_at_host_dtypes)[index] = dtype;
+ }
+ for (int i = 0; i < recv_at_host_dtypes->size(); i++) {
+ if ((*recv_at_host_dtypes)[i] == DT_INVALID) {
+ return errors::Internal("Cannot get datatype for input ", i);
+ }
+ }
+ return Status::OK();
+}
+
+// Builds XlaRecvAtHost node.
+xla::StatusOr BuildRecvAtHostNode(
+ Graph* g, const string& oc_cluster_name,
+ const std::vector& recv_at_host_dtypes, Node* key_placeholder) {
+ NodeDefBuilder recv_at_host_builder(
+ absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"),
+ "_XlaRecvAtHost");
+ NodeDef recv_at_host_def;
+ recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes);
+ // The correct device_ordinal will be inserted during replication in a
+ // subsequent rewrite.
+ recv_at_host_builder.Attr("device_ordinal", 0);
+ recv_at_host_builder.Attr(
+ "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
+ recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
+ TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def));
+ Status s;
+ Node* recv_at_host_node = g->AddNode(recv_at_host_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ return recv_at_host_node;
+}
+
+// Builds XlaRecvAtHost node, and replaces all _Arg nodes with it.
+xla::StatusOr ReplaceArgNodesWithRecvAtHostNode(
+ Graph* g, const string& oc_cluster_name,
+ std::vector* recv_at_host_dtypes, Node* key_placeholder) {
+ // TODO(b/77601805): use out nodes for source node, instead of traversing all
+ // nodes.
+ std::vector arg_nodes = GatherNodesWithType(*g, "_Arg");
+ TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes));
+ TF_ASSIGN_OR_RETURN(
+ Node * recv_at_host_node,
+ BuildRecvAtHostNode(g, oc_cluster_name, *recv_at_host_dtypes,
+ key_placeholder));
+ for (auto* n : arg_nodes) {
+ int index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+ // Record out edges and remove `n` before adding those edges to RecvAtHost.
+ // This is to avoid multiple producers.
+ std::vector out_edge_info;
+ for (auto edge : n->out_edges()) {
+ out_edge_info.push_back(
+ {edge->dst(), edge->src_output(), edge->dst_input()});
+ }
+ g->RemoveNode(n);
+ for (const OutEdgeInfo& edge : out_edge_info) {
+ if (edge.dst_input == Graph::kControlSlot) {
+ g->AddControlEdge(recv_at_host_node, edge.dst);
+ } else {
+ g->AddEdge(recv_at_host_node, index, edge.dst, edge.dst_input);
+ }
+ }
+
+ // Rewrite dst nodes because their input changed.
+ for (int i = 0; i < out_edge_info.size(); i++) {
+ const OutEdgeInfo edge = out_edge_info[i];
+ if (edge.dst_input == Graph::kControlSlot) {
+ continue;
+ }
+
+ Node* dst = edge.dst;
+ NodeDef new_def = dst->def();
+ *new_def.mutable_input(edge.dst_input) =
+ absl::StrCat(recv_at_host_node->name(), ":", index);
+ TF_ASSIGN_OR_RETURN(Node * dst_replace, ReplaceNode(g, dst, new_def));
+
+ // Other edges might have `dst` as dst node as well. Update those edges
+ // with `dst_replace`.
+ for (int j = i + 1; j < out_edge_info.size(); j++) {
+ if (out_edge_info[j].dst == dst) {
+ out_edge_info[j].dst = dst_replace;
+ }
+ }
+ }
+ }
+ g->AddEdge(key_placeholder, 0, recv_at_host_node, 0);
+ return recv_at_host_node;
+}
+
+// Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`.
+Status GetRetDataTypes(const std::vector& ret_nodes,
+ std::vector* send_from_host_dtypes) {
+ send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID);
+ for (auto* n : ret_nodes) {
+ int index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+ DataType dtype;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
+ (*send_from_host_dtypes)[index] = dtype;
+ }
+ for (int i = 0; i < send_from_host_dtypes->size(); i++) {
+ if ((*send_from_host_dtypes)[i] == DT_INVALID) {
+ return errors::Internal("Cannot get datatype for output ", i);
+ }
+ }
+ return Status::OK();
+}
+
+// Builds XlaSendFromHost node.
+xla::StatusOr BuildSendFromHostNode(
+ Graph* g, const string& oc_cluster_name,
+ const std::vector& ret_nodes,
+ const std::vector& send_from_host_dtypes, Node* key_placeholder) {
+ NodeDefBuilder send_from_host_builder(
+ absl::StrCat("outside_compilation_", oc_cluster_name, "_send"),
+ "_XlaSendFromHost");
+ NodeDef send_from_host_def;
+ send_from_host_builder.Attr("Tinputs", send_from_host_dtypes);
+ // The correct device_ordinal will be inserted during replication in a
+ // subsequent rewrite.
+ send_from_host_builder.Attr("device_ordinal", 0);
+ send_from_host_builder.Attr(
+ "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
+ std::vector inputs(send_from_host_dtypes.size());
+ for (auto* n : ret_nodes) {
+ int index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+ if (index < 0 || index >= send_from_host_dtypes.size()) {
+ return errors::Internal("Invalid _Retval index: ", index);
+ }
+ for (auto edge : n->in_edges()) {
+ inputs[index] =
+ NodeDefBuilder::NodeOut{edge->src()->name(), edge->src_output(),
+ edge->src()->output_type(edge->src_output())};
+ }
+ }
+ send_from_host_builder.Input(inputs);
+ send_from_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
+ TF_RETURN_IF_ERROR(send_from_host_builder.Finalize(&send_from_host_def));
+ Status s;
+ Node* send_from_host_node = g->AddNode(send_from_host_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ return send_from_host_node;
+}
+
+// Builds XlaSendFromHost node, and replaces all _Retval nodes with it.
+xla::StatusOr ReplaceRetNodesWithSendFromHostNode(
+ Graph* g, const string& oc_cluster_name,
+ std::vector* send_from_host_dtypes, Node* key_placeholder) {
+ // TODO(b/77601805): use in nodes for sink node, instead of traversing all
+ // nodes.
+ std::vector ret_nodes = GatherNodesWithType(*g, "_Retval");
+ TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes));
+ TF_ASSIGN_OR_RETURN(
+ Node * send_from_host_node,
+ BuildSendFromHostNode(g, oc_cluster_name, ret_nodes,
+ *send_from_host_dtypes, key_placeholder));
+ for (auto* n : ret_nodes) {
+ int index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+ for (auto edge : n->in_edges()) {
+ if (edge->src_output() == Graph::kControlSlot) {
+ g->AddControlEdge(edge->src(), send_from_host_node);
+ } else {
+ g->AddEdge(edge->src(), edge->src_output(), send_from_host_node, index);
+ }
+ }
+ g->RemoveNode(n);
+ }
+ g->AddEdge(key_placeholder, 0, send_from_host_node,
+ send_from_host_dtypes->size());
+ return send_from_host_node;
+}
+
+// Returns input shapes (excluding key placeholder) for `send_from_host_node`
+// if they are all fully defined; absl::nullopt otherwise.
+absl::optional> GetInferredInputShapes(
+ int num_inputs, Node* send_from_host_node) {
+ std::vector results(num_inputs);
+ for (int i = 0; i < num_inputs; i++) {
+ const Edge* e;
+ if (!send_from_host_node->input_edge(i, &e).ok()) {
+ return absl::nullopt;
+ }
+
+ std::vector shapes;
+ if (!GetNodeAttr(e->src()->attrs(), kXlaInferredShapesAttrName, &shapes)
+ .ok()) {
+ return absl::nullopt;
+ }
+
+ const PartialTensorShape shape = shapes[e->src_output()];
+ if (!shape.IsFullyDefined()) {
+ return absl::nullopt;
+ }
+
+ results[e->dst_input()] = shape;
+ }
+ return results;
+}
+
+// Builds XlaHostCompute NodeDef from the outside compilation call node.
+xla::StatusOr BuildXlaHostComputeNodeDef(
+ const Node* call_node, const std::map& host_compute_core) {
+ string original_oc_name;
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
+ NodeDefBuilder host_compute_builder(
+ absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"),
+ "XlaHostCompute");
+
+ // Copy all attributes.
+ for (auto attr : call_node->attrs()) {
+ host_compute_builder.Attr(attr.first, attr.second);
+ }
+
+ // Populate tpu_core assignment.
+ const auto iter = host_compute_core.find(original_oc_name);
+ if (iter != host_compute_core.end()) {
+ int core = iter->second;
+ host_compute_builder.Attr("tpu_core", core);
+ }
+
+ // Populate inputs.
+ std::vector input_dtypes;
+ TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes));
+ std::vector inputs(input_dtypes.size());
+ for (auto e : call_node->in_edges()) {
+ if (e->IsControlEdge()) {
+ continue;
+ }
+
+ if (e->dst_input() < 0 || e->dst_input() >= input_dtypes.size()) {
+ return errors::Internal("Invalid dst_input: ", e->dst_input());
+ }
+ inputs[e->dst_input()] = NodeDefBuilder::NodeOut{
+ e->src()->name(), e->src_output(), input_dtypes[e->dst_input()]};
+ }
+ host_compute_builder.Input(inputs);
+
+ NodeDef new_def;
+ TF_RETURN_IF_ERROR(host_compute_builder.Finalize(&new_def));
+ return new_def;
+}
+
+// Replace outside compilation function call node with XlaHostCompute node.
+// If the function call node has no input/output edges, we will just remove it
+// and not create a XlaHostCompute node.
+Status ReplaceOrRemoveOutsideCompilationCallNode(
+ Graph* g, Node* call_node, const std::map& host_compute_core) {
+ // If the function call node has no input/output edges, just remove it.
+ bool has_edge = false;
+ for (auto e : call_node->in_edges()) {
+ if (!e->IsControlEdge() || e->src() != g->source_node()) {
+ has_edge = true;
+ break;
+ }
+ }
+ for (auto e : call_node->out_edges()) {
+ if (!e->IsControlEdge() || e->dst() != g->sink_node()) {
+ has_edge = true;
+ break;
+ }
+ }
+ if (!has_edge) {
+ VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString();
+ g->RemoveNode(call_node);
+ return Status::OK();
+ }
+
+ // Build XlaHostCompute NodeDef.
+ TF_ASSIGN_OR_RETURN(NodeDef node_def,
+ BuildXlaHostComputeNodeDef(call_node, host_compute_core));
+ TF_ASSIGN_OR_RETURN(Node * host_compute_node,
+ ReplaceNode(g, call_node, node_def));
+ VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
+
+ return Status::OK();
+}
+
+// For an XLA computation, builds host side graph given all outside compilation
+// graphs inside it. The host side graph contains:
+// 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and
+// XlaSendFromHost to this sequencer node, so all outside compilation nodes
+// will be executed *before* this sequencer).
+// 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will
+// replace this node with compilation result node.
+// 3) all outside compilation graphs.
+Status ConstructHostGraph(
+ const string& xla_cluster_name, const string& outside_compilation_attr_name,
+ const std::vector& outside_compilation_host_graphs,
+ FunctionLibraryDefinition* fld, std::unique_ptr