diff --git a/README.md b/README.md
index 05fcb23f7edd657f2ea495d848fadc226e56b524..bbebfc723a65ef1365362cb56298a9399c85e179 100644
--- a/README.md
+++ b/README.md
@@ -81,13 +81,13 @@ The TensorFlow project strives to abide by generally accepted best practices in
| Build Type | Status | Artifacts |
| --- | --- | --- |
-| **Linux CPU** |  | [pypi](https://pypi.org/project/tf-nightly/) |
-| **Linux GPU** |  | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
-| **Linux XLA** | TBA | TBA |
-| **MacOS** |  | [pypi](https://pypi.org/project/tf-nightly/) |
-| **Windows CPU** | [](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [pypi](https://pypi.org/project/tf-nightly/) |
-| **Windows GPU** | [](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
-| **Android** | [](https://ci.tensorflow.org/job/tensorflow-master-android) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) |
+| **Linux CPU** |  | [pypi](https://pypi.org/project/tf-nightly/) |
+| **Linux GPU** |  | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
+| **Linux XLA** |  | TBA |
+| **MacOS** |  | [pypi](https://pypi.org/project/tf-nightly/) |
+| **Windows CPU** |  | [pypi](https://pypi.org/project/tf-nightly/) |
+| **Windows GPU** |  | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
+| **Android** |  | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) |
### Community Supported Builds
@@ -97,7 +97,8 @@ The TensorFlow project strives to abide by generally accepted best practices in
| **IBM s390x** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA |
| **IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA |
| **IBM ppc64le GPU** | [](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA |
-| **Linux CPU with Intel® MKL-DNN®** | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | TBA |
+| **Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) |
+| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6| |[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)
[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)
[1.9.0 py3.6](https://storage.cloud.google.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) |
## For more information
diff --git a/RELEASE.md b/RELEASE.md
index 6b67072f8ecafa08c747f8296c7c2a59eb2350fa..078aafd3746e5ce5c16af15de80d99c1a9e8c567 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,68 @@
+# Release 1.10.0
+
+## Major Features And Improvements
+
+* The `tf.lite` runtime now supports `complex64`.
+* Initial Bigtable integration for `tf.data`.
+* Improved local run behavior in `tf.estimator.train_and_evaluate` which does not reload checkpoints for evaluation.
+* `RunConfig` now sets device_filters to restrict how workers and PS can communicate. This can speed up training and ensure clean shutdowns in some situations. But if you have jobs that require communication between workers, you will have to set custom session_options in your `RunConfig`.
+* Moved Distributions and Bijectors from `tf.contrib.distributions` to [Tensorflow Probability (TFP)](https://github.com/tensorflow/probability). `tf.contrib.distributions` is now deprecated and will be removed by the end of 2018.
+* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. See below for the complete list. New symbols have been added to the following modules: [`tf.debugging`](https://www.tensorflow.org/versions/master/api_docs/python/tf/debugging), [`tf.dtypes`](https://www.tensorflow.org/versions/master/api_docs/python/tf/dtypes), [`tf.image`](https://www.tensorflow.org/versions/master/api_docs/python/tf/image), [`tf.io`](https://www.tensorflow.org/versions/master/api_docs/python/tf/io), [`tf.linalg`](https://www.tensorflow.org/versions/master/api_docs/python/tf/linalg), [`tf.manip`](https://www.tensorflow.org/versions/master/api_docs/python/tf/manip), [`tf.math`](https://www.tensorflow.org/versions/master/api_docs/python/tf/math), [`tf.quantization`](https://www.tensorflow.org/versions/master/api_docs/python/tf/quantization), [`tf.strings`](https://www.tensorflow.org/versions/master/api_docs/python/tf/strings)
+
+## Breaking Changes
+
+* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites).
+* Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake.
+
+## Bug Fixes and Other Changes
+
+* `tf.data`:
+ * `tf.contrib.data.group_by_reducer()` is now available via the public API.
+ * `tf.contrib.data.choose_from_datasets()` is now available via the public API.
+ * Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`.
+* `tf.estimator`:
+ * `Estimator`s now use custom savers included in `EstimatorSpec` scaffolds for saving SavedModels during export.
+ * `EstimatorSpec` will now add a default prediction output for export if no `export_output` is provided, eliminating the need to explicitly include a `PredictOutput` object in the `model_fn` for simple use-cases.
+ * Support sparse_combiner in canned Linear Estimators.
+ * Added batch normalization to `DNNClassifier`, `DNNRegressor`, and `DNNEstimator`.
+ * Adding ranking support for boosted trees.
+ * Adding center bias option for boosted trees.
+* Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables.
+* Add `synchronization` and `aggregation` args to the layer `add_weight()` API. These args will be used for distributed variables.
+* `tf.losses.*` do not add to the global collection when executing eagerly (to avoid leaking memory).
+* Support different summary and checkpoint directories in `tf.train.MonitoredTrainingSession()`.
+* Added IndRNN, IndyGRU, and IndyLSTM cells to `tf.contrib.rnn`.
+* Add safe static factory functions for SparseTensor and convert all CHECKs to DCHECKs. Using the constructor directly is unsafe and deprecated.
+* Make the Bigtable client connection pool configurable & increase the default # of connections for performance.
+* Added derivative of `tf.random_gamma` with respect to the alpha parameter.
+* Added derivative of `tf.igamma(a, x)` and `tf.igammac(a, x)` with respect to a.
+* Modified Bessel functions of order zero and one.
+* Add FillTriangular Bijector to create triangular matrices.
+* Added support for Type III DCT, and `tf.spectral.idct(type=2|3)`.
+* Correctly handle CuDNN RNN weight loaded when nest in `TimeDistributed`.
+* Adding per-element weight support for `WALSComputePartialLhsAndRhsOp`.
+* ZerosLike and OnesLike ops treated as constants by Graph Transform Tool.
+* Gamma distribution and the derived distributions (Beta, Dirichlet, Student's t, inverse Gamma) now fully reparameterized.
+* Java: Experimental wrapper classes to make graph generation easier. Thanks @karllessard and @kbsriram
+* Build & link in secure gRPC components (switch from the insecure grpc dependency to secure grpc dependency).
+* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. List of new endpoints:
+ * New endpoints in `tf.image` namespace: `tf.image.extract_image_patches`
+ * New endpoints in `tf.debugging` namespace: `tf.debugging.check_numerics`, `tf.debugging.is_finite`, `tf.debugging.is_inf`, `tf.debugging.is_nan`.
+ * New endpoints in `tf.dtypes` namespace: `tf.dtypes.as_string`.
+ * New endpoints in `tf.io` namespace: `tf.io.decode_base64`, `tf.io.decode_compressed`, `tf.io.decode_json_example`, `tf.io.decode_raw`, `tf.io.encode_base64`, `tf.io.matching_files`, `tf.io.parse_tensor`, `tf.io.read_file, `tf.io.write_file`.
+ * New endpoints in tf.linalg namespace: `tf.linalg.cross`, `tf.linalg.tensor_diag` (corresponds to `tf.diag`), `tf.linalg.tensor_diag_part` (corresponds to `tf.diag_part`).
+ * New endpoints in tf.manip namespace: `tf.manip.batch_to_space_nd`, `tf.manip.gather_nd`, `tf.manip.reshape`, `tf.manip.reverse`, `tf.manip.scatter_nd`, `tf.manip.space_to_batch_nd`, `tf.manip.tile`
+ * New endpoints in tf.math namespace: `tf.math.acos`, `tf.math.acosh`, `tf.math.add`, `tf.math.asin`, `tf.math.asinh`, `tf.math.atan`, `tf.math.atan2`, `tf.math.atanh`, `tf.math.betainc`, `tf.math.ceil`, `tf.math.cos`, `tf.math.cosh`, `tf.math.digamma`, `tf.math.equal`, `tf.math.erfc`, `tf.math.exp`, `tf.math.expm1`, `tf.math.floor`, `tf.math.greater`, `tf.math.greater_equal`, `tf.math.igamma`, `tf.math.igammac`, `tf.math.invert_permutation`, `tf.math.less`, `tf.math.less_equal`, `tf.math.lgamma`, `tf.math.log`, `tf.math.log1p`, `tf.math.logical_and`, `tf.math.logical_not`, `tf.math.logical_or`, `tf.math.maximum`, `tf.math.minimum`, `tf.math.not_equal`, `tf.math.polygamma`, `tf.math.reciprocal`, `tf.math.rint`, `tf.math.rsqrt`, `tf.math.segment_max`, `tf.math.segment_mean`, `tf.math.segment_min`, `tf.math.segment_prod`, `tf.math.segment_sum`, `tf.math.sin`, `tf.math.sinh`, `tf.math.softplus`, `tf.math.softsign`, `tf.math.squared_difference`, `tf.math.tan`, `tf.math.unsorted_segment_max`, `tf.math.unsorted_segment_min`, `tf.math.unsorted_segment_prod`, `tf.math.unsorted_segment_sum`, `tf.math.zeta`.
+ * New endpoints in `tf.quantization` namespace: `tf.quantization.dequantize`, `tf.quantization.fake_quant_with_min_max_args`, `tf.quantization.fake_quant_with_min_max_args_gradient`, `tf.quantization.fake_quant_with_min_max_vars`, `tf.quantization.fake_quant_with_min_max_vars_gradient`, `tf.quantization.fake_quant_with_min_max_vars_per_channel`, `tf.quantization.fake_quant_with_min_max_vars_per_channel_gradient`.
+ * New endpoints in tf.strings namespace: `tf.strings.join` (corresponds to `tf.string_join`), `tf.strings.regex_replace`, `tf.strings.to_number` (corresponds to `tf.string_to_number`), `tf.strings.strip` (corresponds to `tf.string_strip`), `tf.strings.substr`, `tf.strings.to_hash_bucket` (corresponds to `tf.string_to_hash_bucket`), `tf.strings.to_hash_bucket_fast` (corresponds to `tf.string_to_hash_bucket_fast`), `tf.strings.to_hash_bucket_strong` (corresponds to `tf.string_to_hash_bucket_strong`).
+
+
+## Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, Andrei Nigmatulin, Andrew Ginns, BjøRn Moholt, Brett Koonce, Chengzhi Chen, Chinmay Das, Christian Ertler, Christoph Boeddeker, Clayne Robison, Courtial Florian, ctiijima, Dan Douthit, Dan J, Dan Ringwalt, EFanZh, Emanuele Ballarin, eqy, Evgeniy Zheltonozhskiy, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, G K, gracehoney, Guillaume Klein, Guozhong Zhuang, Hsien-Yang Li, hsm207, ImSheridan, Jayaram Bobba, Jiandong Ruan, Jie, Joel Shor, Jonas Rauber, Jongmin Baek, jsawruk, Karan Kaw, Karl Lessard, karl@kubx.ca, Kb Sriram, KinmanLam, leiiwang, Li, Yiqiang, Loo Rong Jie, Mahmoud Abuzaina, Mahmoud Aslan, ManHyuk, Martin Patz, Martin Zeitler, mktozk, Mohammad Ashraf Bhuiyan, mrTsjolder, Naman Bhalla, Nick Felt, Nicolas Lopez, Niranjan Hasabnis, Nishidha Panpaliya, Nitish, nrstott, Nutti, Parag Jain, PeterLee, Philipp Jund, Rach L, Rafal Wojdyla, Roland Zimmermann, Sergei Lebedev, SneakyFish5, Soila Kavulya, Sriram Veturi, Steven Schmatz, Taehoon Lee, Tang, Wenyi, Taras Sereda, Ted Chang, Tim Zaman, Tristan Rice, tucan, vchigrin, Vikram Tiwari, Vincent, WeberXie, William D. Irons, Yan Facai (颜发才), Yong Tang, Yu Yi, Yuxin Wu, Zé ViníCius
+
# Release 1.9.0
## Major Features And Improvements
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 388ca3f293ebfa120037b75fe70c66b9d715c051..f8cd6820244aa05724ce0980419eb7b77962ff91 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -381,6 +381,14 @@ config_setting(
},
)
+# Setting to use when loading kernels dynamically
+config_setting(
+ name = "dynamic_loaded_kernels",
+ define_values = {
+ "dynamic_loaded_kernels": "true",
+ },
+)
+
config_setting(
name = "using_cuda_nvcc",
define_values = {
@@ -408,14 +416,6 @@ config_setting(
visibility = ["//visibility:public"],
)
-# TODO(laigd): consider removing this option and make TensorRT enabled
-# automatically when CUDA is enabled.
-config_setting(
- name = "with_tensorrt_support",
- values = {"define": "with_tensorrt_support=true"},
- visibility = ["//visibility:public"],
-)
-
package_group(
name = "internal",
packages = [
@@ -441,11 +441,6 @@ filegroup(
),
)
-filegroup(
- name = "docs_src",
- data = glob(["docs_src/**/*.md"]),
-)
-
cc_library(
name = "grpc",
deps = select({
@@ -589,6 +584,7 @@ exports_files(
gen_api_init_files(
name = "tensorflow_python_api_gen",
srcs = ["api_template.__init__.py"],
+ api_version = 1,
root_init_template = "api_template.__init__.py",
)
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 10bc8cdbee5a9df6d2084c10adab4ed6e5e6f0d3..19ccb6e71d2f3021c1ce5c8905d8a72059c1cfcb 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -52,6 +52,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
@@ -2389,6 +2390,12 @@ void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
TF_Output* dx, TF_Status* status, TF_Output* dy) {
+ TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
+}
+
+void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
+ int ny, TF_Output* x, int nx, TF_Output* dx,
+ TF_Status* status, TF_Output* dy) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"Adding gradients is not supported in Android. File a bug at "
@@ -2405,9 +2412,29 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
const int first_new_node_id = g->graph.num_node_ids();
+ string prefix_cmp;
+ const char* child_scope_name;
+ if (prefix == nullptr) {
+ child_scope_name = "gradients";
+ } else {
+ prefix_cmp = string(prefix) + "/";
+ // The operation should fail if the provided name prefix has already been
+ // used in this graph
+ for (const auto& pair : g->name_map) {
+ const string& name = pair.first;
+ if (name.compare(prefix) == 0 ||
+ tensorflow::str_util::StartsWith(name, prefix_cmp)) {
+ status->status = InvalidArgument(
+ "prefix [", prefix,
+ "] conflicts with existing node in the graph named [", name, "]");
+ return;
+ }
+ }
+ child_scope_name = prefix;
+ }
tensorflow::Scope scope =
NewInternalScope(&g->graph, &status->status, &g->refiner)
- .NewSubScope("gradients");
+ .NewSubScope(child_scope_name);
if (dx != nullptr) {
std::vector dx_arg = OutputsFromTFOutputs(dx, ny);
@@ -2422,6 +2449,18 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
Node* n = g->graph.FindNodeId(i);
if (n == nullptr) continue;
+
+ // Adding the gradients to the graph can alter the prefix to prevent
+ // name collisions only if this prefix has not been provided explicitly
+ // by the user. If it was provided, assert that it remained intact.
+ if (prefix != nullptr &&
+ !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
+ status->status = tensorflow::errors::Internal(
+ "BUG: The gradients prefix have been unexpectedly altered when "
+ "adding the nodes to the graph. This is a bug. Please file an "
+ "issue at https://github.com/tensorflow/tensorflow/issues.");
+ return;
+ }
// We have a convoluted scheme here: Using the C++ graph construction API
// to add potentially many nodes to the graph without running the checks
// (such as uniqueness of the names of nodes) we run with other functions
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index c8ae6f2dd1780c4fe50ff1924be8d2e9a7502cf0..850f6ecd637d768bca99720e0add07680829e17a 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1131,6 +1131,7 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params);
// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
+//
// `dx` are used as initial gradients (which represent the symbolic partial
// derivatives of some loss function `L` w.r.t. `y`).
// `dx` must be nullptr or have size `ny`.
@@ -1139,6 +1140,12 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params);
// The partial derivatives are returned in `dy`. `dy` should be allocated to
// size `nx`.
//
+// Gradient nodes are automatically named under the "gradients/" prefix. To
+// guarantee name uniqueness, subsequent calls to the same graph will
+// append an incremental tag to the prefix: "gradients_1/", "gradients_2/", ...
+// See TF_AddGradientsWithPrefix, which provides a means to specify a custom
+// name prefix for operations added to a graph to compute the gradients.
+//
// WARNING: This function does not yet support all the gradients that python
// supports. See
// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
@@ -1147,6 +1154,33 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
TF_Output* x, int nx, TF_Output* dx,
TF_Status* status, TF_Output* dy);
+// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
+// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
+// This is a variant of TF_AddGradients that allows to caller to pass a custom
+// name prefix to the operations added to a graph to compute the gradients.
+//
+// `dx` are used as initial gradients (which represent the symbolic partial
+// derivatives of some loss function `L` w.r.t. `y`).
+// `dx` must be nullptr or have size `ny`.
+// If `dx` is nullptr, the implementation will use dx of `OnesLike` for all
+// shapes in `y`.
+// The partial derivatives are returned in `dy`. `dy` should be allocated to
+// size `nx`.
+// `prefix` names the scope into which all gradients operations are being added.
+// `prefix` must be unique within the provided graph otherwise this operation
+// will fail. If `prefix` is nullptr, the default prefixing behaviour takes
+// place, see TF_AddGradients for more details.
+//
+// WARNING: This function does not yet support all the gradients that python
+// supports. See
+// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
+// for instructions on how to add C++ more gradients.
+TF_CAPI_EXPORT void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix,
+ TF_Output* y, int ny,
+ TF_Output* x, int nx,
+ TF_Output* dx, TF_Status* status,
+ TF_Output* dy);
+
// Create a TF_Function from a TF_Graph
//
// Params:
@@ -1236,6 +1270,11 @@ TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction(
int noutputs, const TF_Output* outputs, const char* const* output_names,
const TF_FunctionOptions* opts, const char* description, TF_Status* status);
+// Returns the name of the graph function.
+// The return value points to memory that is only usable until the next
+// mutation to *func.
+TF_CAPI_EXPORT extern const char* TF_FunctionName(TF_Function* func);
+
// Write out a serialized representation of `func` (as a FunctionDef protocol
// message) to `output_func_def` (allocated by TF_NewBuffer()).
// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer()
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 170046c8024dc85c899108b254cd3a95a3be4096..69b3ffe2a1f620e346405607ecf742fb863aa644 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -84,6 +84,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
return ret;
}
+TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
+ tensorflow::RunOptions options;
+ if (enable_full_trace) {
+ options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
+ } else {
+ options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
+ }
+ TF_Buffer* ret = TF_NewBuffer();
+ TF_CHECK_OK(MessageToBuffer(options, ret));
+ return ret;
+}
+
const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
tensorflow::mutex_lock c(graph->mu);
const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 2d81c01e0dd056e9beb3b45f24809381554a7924..6617c5a572e90e78369f73d714f39942f213040f 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -70,6 +70,12 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig(
unsigned char enable_xla_compilation,
unsigned char gpu_memory_allow_growth);
+// Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level
+// is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE
+// otherwise.
+TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions(
+ unsigned char enable_full_trace);
+
// Returns the graph content in a human-readable format, with length set in
// `len`. The format is subject to change in the future.
// The returned string is heap-allocated, and caller should call free() on it.
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index 384e6c8cb97022264c5327da5ca5861057608fbe..a2c5a42c11361779de61b515e0f08dcc45e609b9 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -536,6 +536,10 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
return tf_function;
}
+const char* TF_FunctionName(TF_Function* func) {
+ return func->fdef.signature().name().c_str();
+}
+
void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
const TF_Function* grad, TF_Status* status) {
if (func == nullptr) {
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index f7ca219c896b2a7c07fc4d0739c70f2666652672..73fe73769bc1219ce865149d67d333c53371ccc5 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -193,6 +193,7 @@ class CApiFunctionTest : public ::testing::Test {
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
ASSERT_NE(func_, nullptr);
+ ASSERT_EQ(std::string(func_name_), std::string(TF_FunctionName(func_)));
TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
}
@@ -1618,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
TF_DeleteFunction(func1);
}
+// This test only works when the TF build includes XLA compiler. One way to set
+// this up is via bazel build option "--define with_xla_support=true".
+//
+// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
+// something like TENSORFLOW_CAPI_USE_XLA.
+#ifdef TENSORFLOW_EAGER_USE_XLA
+TEST_F(CApiFunctionTest, StatelessIf_XLA) {
+ TF_Function* func;
+ const std::string funcName = "BranchFunc";
+ DefineFunction(funcName.c_str(), &func);
+ TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_Operation* feed = Placeholder(host_graph_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_OperationDescription* desc =
+ TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
+ TF_AddInput(desc, {true_cond, 0});
+ TF_Output inputs[] = {{feed, 0}};
+ TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_SetAttrType(desc, "Tcond", TF_BOOL);
+ TF_DataType inputType = TF_INT32;
+ TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
+ TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
+ TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
+ TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
+ TF_SetDevice(desc, "/device:XLA_CPU:0");
+ auto op = TF_FinishOperation(desc, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ ASSERT_NE(op, nullptr);
+
+ // Create a session for this graph.
+ CSession csession(host_graph_, s_, /*use_XLA*/ true);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Run the graph.
+ csession.SetInputs({{feed, Int32Tensor(17)}});
+ csession.SetOutputs({op});
+ csession.Run(s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_Tensor* out = csession.output_tensor(0);
+ ASSERT_TRUE(out != nullptr);
+ EXPECT_EQ(TF_INT32, TF_TensorType(out));
+ EXPECT_EQ(0, TF_NumDims(out)); // scalar
+ ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
+ int32* output_contents = static_cast(TF_TensorData(out));
+ EXPECT_EQ(-17, *output_contents);
+
+ // Clean up
+ csession.CloseAndDelete(s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_DeleteFunction(func);
+}
+#endif // TENSORFLOW_EAGER_USE_XLA
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index e674b1623cf540eb8024d9be5ed8d77aa2fe17ba..aa2a537f03be31ae45ff3d6f7815b449d661cf9c 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1483,8 +1483,8 @@ class CApiGradientsTest : public ::testing::Test {
BuildSuccessGraph(inputs, outputs);
BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
- AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs);
-
+ AddGradients(grad_inputs_provided, nullptr, inputs, 2, outputs, 1,
+ grad_outputs);
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// Compare that the graphs match.
@@ -1505,7 +1505,8 @@ class CApiGradientsTest : public ::testing::Test {
BuildErrorGraph(inputs, outputs);
- AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs);
+ AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1,
+ grad_outputs);
string expected_msg =
"No gradient defined for op: TestOpWithNoGradient. Please see "
@@ -1549,19 +1550,20 @@ class CApiGradientsTest : public ::testing::Test {
EXPECT_EQ(*a_data, *b_data);
}
- void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs,
- TF_Output* outputs, int noutputs, TF_Output* grad_outputs) {
+ void AddGradients(bool grad_inputs_provided, const char* prefix,
+ TF_Output* inputs, int ninputs, TF_Output* outputs,
+ int noutputs, TF_Output* grad_outputs) {
if (grad_inputs_provided) {
TF_Output grad_inputs[1];
const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
TF_Operation* grad_inputs_op =
FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
grad_inputs[0] = TF_Output{grad_inputs_op, 0};
- TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs,
- s_, grad_outputs);
+ TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
+ ninputs, grad_inputs, s_, grad_outputs);
} else {
- TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_,
- grad_outputs);
+ TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
+ ninputs, nullptr, s_, grad_outputs);
}
}
@@ -1706,6 +1708,20 @@ class CApiGradientsTest : public ::testing::Test {
return op;
}
+ void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1,
+ const char* prefix2 = nullptr) {
+ TF_Output inputs[2];
+ TF_Output outputs[1];
+ TF_Output grad_outputs[2];
+
+ BuildSuccessGraph(inputs, outputs);
+
+ AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs);
+ if (prefix2 != nullptr) {
+ AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs);
+ }
+ }
+
TF_Status* s_;
TF_Graph* graph_;
TF_Graph* expected_graph_;
@@ -1725,6 +1741,56 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
TestGradientsError(false);
}
+TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients_1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) {
+ BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope/gradients_1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) {
+ BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope_1/gradients");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/sub");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) {
+ BuildGraphAndAddGradientsWithPrefixes("Const_0");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/MatMul");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients/MatMul", "gradients");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients/sub", "gradients");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
void ScalarFloatFromTensor(const TF_Tensor* t, float* f) {
ASSERT_TRUE(t != nullptr);
ASSERT_EQ(TF_FLOAT, TF_TensorType(t));
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 24eb6c069b21349fce288db3e79fbf14e824ad11..f15d9ee20adb31a0b76e2cd0d1e67f17a9deff05 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -26,6 +26,10 @@ limitations under the License.
using tensorflow::GraphDef;
using tensorflow::NodeDef;
+static void BoolDeallocator(void* data, size_t, void* arg) {
+ delete[] static_cast(data);
+}
+
static void Int32Deallocator(void* data, size_t, void* arg) {
delete[] static_cast(data);
}
@@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast(data);
}
+TF_Tensor* BoolTensor(bool v) {
+ const int num_bytes = sizeof(bool);
+ bool* values = new bool[1];
+ values[0] = v;
+ return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator,
+ nullptr);
+}
+
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
@@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
return op;
}
+TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor);
+ return Const(tensor.get(), graph, s, name);
+}
+
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name) {
unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index 38313d647ca93d4779bb1325f8ed7bde4b743879..7eeb1ee5e17ad7e5644f8bc8a18ca967b108475d 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -31,6 +31,8 @@ using ::tensorflow::string;
typedef std::unique_ptr
unique_tensor_ptr;
+TF_Tensor* BoolTensor(int32_t v);
+
// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
@@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
const char* name = "const");
+TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
+ const char* name = "scalar");
+
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 6c510536d6f2a586b91baf96fa41b779db2c8d35..dfb1c9a37644c726e1eabab775593596d5b556b9 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices(
tensorflow::Status CreateRemoteContexts(
const std::vector& remote_workers, int64 rendezvous_id,
- const tensorflow::ServerDef& server_def,
+ int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
tensorflow::gtl::FlatMap* remote_contexts) {
for (int i = 0; i < remote_workers.size(); i++) {
@@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts(
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
+ request.set_keep_alive_secs(keep_alive_secs);
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
if (eager_client == nullptr) {
return tensorflow::errors::Internal(
@@ -150,8 +151,9 @@ tensorflow::Status CreateRemoteContexts(
return tensorflow::Status::OK();
}
-tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
- TFE_Context** ctx) {
+tensorflow::Status UpdateTFE_ContextWithServerDef(
+ int keep_alive_secs, const tensorflow::ServerDef& server_def,
+ TFE_Context* ctx) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
@@ -165,12 +167,12 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
} \
} while (0);
- string worker_name = tensorflow::strings::StrCat(
- "/job:", opts->server_def.job_name(),
- "/replica:0/task:", opts->server_def.task_index());
+ string worker_name =
+ tensorflow::strings::StrCat("/job:", server_def.job_name(),
+ "/replica:0/task:", server_def.task_index());
std::unique_ptr server;
- LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server));
+ LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server));
tensorflow::GrpcServer* grpc_server =
dynamic_cast(server.get());
@@ -202,15 +204,15 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
// Initialize remote eager workers.
tensorflow::gtl::FlatMap remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
- remote_workers, rendezvous_id, opts->server_def,
- remote_eager_workers.get(), opts->async, &remote_contexts));
+ remote_workers, rendezvous_id, keep_alive_secs, server_def,
+ remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
- session_name, opts->server_def, true));
+ session_name, server_def, true));
std::shared_ptr worker_session;
TF_RETURN_IF_ERROR(
@@ -221,10 +223,11 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
auto* device_mgr = grpc_server->worker_env()->device_mgr;
- *ctx = new TFE_Context(opts->session_options.options, opts->policy,
- opts->async, device_mgr, r, std::move(server),
- std::move(remote_eager_workers),
- std::move(remote_device_mgr), remote_contexts);
+
+ ctx->context.InitializeRemote(std::move(server),
+ std::move(remote_eager_workers),
+ std::move(remote_device_mgr), remote_contexts,
+ r, device_mgr, keep_alive_secs);
return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
@@ -249,15 +252,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
options->policy = policy;
}
-TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
- TFE_ContextOptions* options, const void* proto, size_t proto_len,
- TF_Status* status) {
- if (!options->server_def.ParseFromArray(proto, proto_len)) {
- status->status = tensorflow::errors::InvalidArgument(
- "Invalid tensorflow.ServerDef protocol buffer");
- }
-}
-
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char async,
TF_Status* status) {
@@ -267,12 +261,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
- if (!opts->server_def.job_name().empty()) {
- TFE_Context* ctx = nullptr;
- status->status = NewRemoteAwareTFE_Context(opts, &ctx);
- return ctx;
- }
-
std::vector devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
@@ -288,7 +276,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
opts->async, std::move(device_mgr), r);
}
-void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; }
+void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
@@ -301,6 +289,22 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
+// Set server_def on the context, possibly updating it.
+TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
+ int keep_alive_secs,
+ const void* proto,
+ size_t proto_len,
+ TF_Status* status) {
+ tensorflow::ServerDef server_def;
+ if (!server_def.ParseFromArray(proto, proto_len)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Invalid tensorflow.ServerDef protocol buffer");
+ return;
+ }
+ status->status =
+ UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
+}
+
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context.SetThreadLocalDevicePlacementPolicy(
@@ -336,7 +340,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
- DCHECK(h);
+ if (h == nullptr) return;
if (h->handle) {
h->handle->Unref();
}
@@ -348,6 +352,11 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return -1;
+ }
int result;
status->status = h->handle->NumDims(&result);
return result;
@@ -355,12 +364,22 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return -1;
+ }
tensorflow::int64 result;
status->status = h->handle->Dim(dim_index, &result);
return result;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return nullptr;
+ }
tensorflow::Device* d = nullptr;
status->status = h->handle->OpDevice(&d);
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
@@ -368,6 +387,11 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return nullptr;
+ }
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
tensorflow::Device* d = nullptr;
tensorflow::Device* op_device = nullptr;
@@ -700,6 +724,10 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
}
} // namespace
+void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
+
+void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
+
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index fdbd5374b2afe815c3a81b453930eb8f1fa351d3..a0ebc6fa0a22ed61be91c2974352c2988fb4cd92 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -81,16 +81,6 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
-// A tensorflow.ServerDef specifies remote workers (in addition to the current
-// workers name). Operations created on this context can then be executed on
-// any of these remote workers by setting an appropriate device.
-//
-// If the following is set, all servers identified by the
-// ServerDef must be up when the context is created.
-TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
- TFE_ContextOptions* options, const void* proto, size_t proto_len,
- TF_Status* status);
-
// Destroy an options object.
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
@@ -102,8 +92,7 @@ typedef struct TFE_Context TFE_Context;
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(
const TFE_ContextOptions* opts, TF_Status* status);
-TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx,
- TF_Status* status);
+TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx);
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
TF_Status* status);
@@ -128,6 +117,18 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
unsigned char async,
TF_Status* status);
+// A tensorflow.ServerDef specifies remote workers (in addition to the current
+// workers name). Operations created on this context can then be executed on
+// any of these remote workers by setting an appropriate device.
+//
+// If the following is set, all servers identified by the
+// ServerDef must be up when the context is created.
+TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
+ int keep_alive_secs,
+ const void* proto,
+ size_t proto_len,
+ TF_Status* status);
+
// Causes the calling thread to block till all ops dispatched in async mode
// have been executed. Note that "execution" here refers to kernel execution /
// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
@@ -380,6 +381,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);
+// Some TF ops need a step container to be set to limit the lifetime of some
+// resources (mostly TensorArray and Stack, used in while loop gradients in
+// graph mode). Calling this on a context tells it to start a step.
+TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx);
+
+// Ends a step. When there is no active step (that is, every started step has
+// been ended) step containers will be cleared. Note: it is not safe to call
+// TFE_ContextEndStep while ops which rely on the step container may be running.
+TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 4c5077023d5bb3b83808bf3908e7110dd026e3ad..a5c0681e2e4eddae08954d9d0178ca96a3f8f29a 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -59,7 +59,6 @@ struct TFE_ContextOptions {
// true if async execution is enabled.
bool async = false;
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
- tensorflow::ServerDef server_def;
};
struct TFE_Context {
@@ -73,23 +72,6 @@ struct TFE_Context {
default_policy),
async, std::move(device_mgr), rendezvous) {}
- explicit TFE_Context(
- const tensorflow::SessionOptions& opts,
- TFE_ContextDevicePlacementPolicy default_policy, bool async,
- tensorflow::DeviceMgr* local_device_mgr,
- tensorflow::Rendezvous* rendezvous,
- std::unique_ptr server,
- std::unique_ptr remote_eager_workers,
- std::unique_ptr remote_device_mgr,
- const tensorflow::gtl::FlatMap&
- remote_contexts)
- : context(opts,
- static_cast(
- default_policy),
- async, local_device_mgr, rendezvous, std::move(server),
- std::move(remote_eager_workers), std::move(remote_device_mgr),
- remote_contexts) {}
-
tensorflow::EagerContext context;
};
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 3504a8b5e78480732d3454097c1b2197ac2b2e17..71d5f3613c89762633113b4e1dfb82b8199a1cd1 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -49,7 +49,7 @@ void BM_InitOp(int iters) {
}
tensorflow::testing::StopTiming();
TFE_DeleteTensorHandle(m);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -80,7 +80,7 @@ void BM_Execute(int iters, int async) {
tensorflow::testing::StopTiming();
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -95,7 +95,7 @@ TEST(CAPI, Context) {
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const int num_devices = TF_DeviceListCount(devices);
@@ -108,14 +108,14 @@ TEST(CAPI, Context) {
TF_DeleteStatus(status);
}
-tensorflow::ServerDef GetServerDef(int num_tasks) {
+tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
- server_def.set_job_name("localhost");
+ server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
- job_def->set_name("localhost");
+ job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
@@ -124,6 +124,10 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
return server_def;
}
+tensorflow::ServerDef GetServerDef(int num_tasks) {
+ return GetServerDef("localhost", num_tasks);
+}
+
void TestRemoteExecute(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
@@ -140,9 +144,6 @@ void TestRemoteExecute(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
- TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
- status);
- EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
@@ -150,6 +151,9 @@ void TestRemoteExecute(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
+ TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char remote_device_name[] =
@@ -195,8 +199,8 @@ void TestRemoteExecute(bool async) {
TFE_DeleteOp(matmul);
TFE_ContextAsyncWait(ctx, status);
- TFE_DeleteContext(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
@@ -229,15 +233,15 @@ void TestRemoteExecuteSilentCopies(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
- TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
- status);
- EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
+ TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
@@ -281,7 +285,7 @@ void TestRemoteExecuteSilentCopies(bool async) {
TFE_DeleteOp(matmul);
TFE_ContextAsyncWait(ctx, status);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
@@ -296,6 +300,147 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
}
+void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
+ const std::vector& expected_values) {
+ std::unique_ptr status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ std::unique_ptr actual_values(new float[expected_values.size()]);
+ EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
+ memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+
+ for (int i = 0; i < expected_values.size(); i++) {
+ EXPECT_EQ(expected_values[i], actual_values[i])
+ << "Mismatch in expected values at (zero-based) index " << i;
+ }
+}
+
+void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
+ const char* remote_device_name,
+ const char* local_device_name) {
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
+
+ TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
+ TFE_OpSetDevice(matmul, remote_device_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ auto* retval_task0 =
+ TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
+
+ TFE_DeleteTensorHandle(retval_task0);
+ TFE_DeleteTensorHandle(h0_task0);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(matmul);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+}
+
+void TestRemoteExecuteChangeServerDef(bool async) {
+ tensorflow::ServerDef server_def = GetServerDef(2);
+
+ // This server def has the task index set to 0.
+ string serialized = server_def.SerializeAsString();
+
+ server_def.set_task_index(1);
+
+ std::unique_ptr worker_server;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def, tensorflow::Env::Default(), &worker_server)
+ .ok());
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast(async));
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ const char remote_device_name[] =
+ "/job:localhost/replica:0/task:1/device:CPU:0";
+ const char local_device_name[] =
+ "/job:localhost/replica:0/task:0/device:CPU:0";
+ CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server.release();
+
+ // Update the server def with a new set of names (worker instead of
+ // localhost).
+ tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
+ serialized = updated_server_def.SerializeAsString();
+
+ updated_server_def.set_task_index(1);
+ tensorflow::Status s = tensorflow::GrpcServer::Create(
+ updated_server_def, tensorflow::Env::Default(), &worker_server);
+ ASSERT_TRUE(s.ok()) << s.error_message();
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Create a new tensor_handle.
+ TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle();
+
+ // Check that copying it to the old remote device (named localhost) fails.
+ TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
+ EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Copying and executing on the new remote device works.
+ const char new_remote_device_name[] =
+ "/job:worker/replica:0/task:1/device:CPU:0";
+ const char new_local_device_name[] =
+ "/job:worker/replica:0/task:0/device:CPU:0";
+
+ auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
+ h0_task0_new, ctx, new_remote_device_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_DeleteTensorHandle(h0_task0_new);
+ TFE_DeleteTensorHandle(h0_task1_new);
+
+ CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
+ new_local_device_name);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_DeleteStatus(status);
+
+ TFE_DeleteContext(ctx);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server.release();
+}
+
+TEST(CAPI, RemoteExecuteChangeServerDef) {
+ TestRemoteExecuteChangeServerDef(false);
+}
+TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
+ TestRemoteExecuteChangeServerDef(true);
+}
+
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
@@ -380,8 +525,7 @@ void TensorHandleCopyBetweenDevices(bool async) {
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
- TFE_DeleteContext(ctx, status.get());
- EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleCopyBetweenDevices) {
@@ -418,7 +562,7 @@ void TensorHandleCopyBetweenDevicesError(bool async) {
TFE_DeleteTensorHandle(hcopy);
TFE_DeleteTensorHandle(hcpu);
if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice);
- TFE_DeleteContext(ctx, status.get());
+ TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleCopyBetweenDevicesError) {
@@ -451,7 +595,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
- TFE_DeleteContext(ctx, status.get());
+ TFE_DeleteContext(ctx);
return;
}
const string gpu_1_name(TF_DeviceListName(devices, 1, status.get()));
@@ -484,8 +628,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
- TFE_DeleteContext(ctx, status.get());
- EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
@@ -533,8 +676,7 @@ void TensorHandleSilentCopy(bool async) {
TFE_DeleteTensorHandle(hcpu);
TFE_ContextAsyncWait(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
- TFE_DeleteContext(ctx, status.get());
- EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); }
@@ -580,8 +722,7 @@ void TensorHandleSilentCopyLocal(bool async) {
TFE_DeleteTensorHandle(hcpu);
TFE_ContextAsyncWait(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
- TFE_DeleteContext(ctx, status.get());
- EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
TEST(CAPI, TensorHandleSilentCopyLocalAsync) {
@@ -614,11 +755,47 @@ void SetAndGetOpDevices(bool async) {
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
+TEST(CAPI, TensorHandleNullptr) {
+ TFE_TensorHandle* h = nullptr;
+ std::unique_ptr status(
+ TF_NewStatus(), TF_DeleteStatus);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+ ASSERT_EQ(t, nullptr);
+ ASSERT_EQ("The passed in handle is a nullptr",
+ string(TF_Message(status.get())));
+
+ TF_SetStatus(status.get(), TF_OK, "");
+
+ const char* device_name = TFE_TensorHandleDeviceName(h, status.get());
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+ ASSERT_EQ(device_name, nullptr);
+ ASSERT_EQ("The passed in handle is a nullptr",
+ string(TF_Message(status.get())));
+
+ TF_SetStatus(status.get(), TF_OK, "");
+
+ int num_dims = TFE_TensorHandleNumDims(h, status.get());
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+ ASSERT_EQ(num_dims, -1);
+ ASSERT_EQ("The passed in handle is a nullptr",
+ string(TF_Message(status.get())));
+
+ TF_SetStatus(status.get(), TF_OK, "");
+
+ int dim = TFE_TensorHandleDim(h, 0, status.get());
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+ ASSERT_EQ(dim, -1);
+ ASSERT_EQ("The passed in handle is a nullptr",
+ string(TF_Message(status.get())));
+}
+
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -640,7 +817,7 @@ void Execute_MatMul_CPU(bool async) {
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
@@ -712,7 +889,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
TFE_DeleteTensorHandle(m1);
TFE_DeleteTensorHandle(m2);
TFE_DeleteTensorHandle(retvals[0]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) {
@@ -743,7 +920,7 @@ void Execute_MatMul_CPU_Type_Error(bool async) {
if (retvals[0] != nullptr) {
TFE_DeleteTensorHandle(retvals[0]);
}
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
@@ -781,7 +958,7 @@ TEST(CAPI, Execute_Min_CPU) {
TF_DeleteTensor(t);
EXPECT_EQ(1, output[0]);
EXPECT_EQ(3, output[1]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -823,7 +1000,7 @@ void Execute_MatMul_XLA_CPU(bool async) {
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); }
@@ -862,7 +1039,7 @@ void Execute_Min_XLA_CPU(bool async) {
TF_DeleteTensor(t);
EXPECT_EQ(1, output[0]);
EXPECT_EQ(3, output[1]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); }
@@ -898,7 +1075,7 @@ void ExecuteWithTracing(bool async) {
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
@@ -974,7 +1151,7 @@ TEST(CAPI, Function_ident_CPU) {
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -1044,7 +1221,7 @@ TEST(CAPI, Function_ident_XLA_CPU) {
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -1120,7 +1297,7 @@ void FunctionDefAndExecute(bool async) {
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -1161,7 +1338,7 @@ void BM_ExecuteFunction(int iters, int async) {
tensorflow::testing::StopTiming();
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval[0]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -1249,7 +1426,7 @@ TEST(CAPI, Variables) {
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteTensorHandle(value_handle);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -1288,7 +1465,7 @@ void BM_ReadVariable(int iters) {
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(var_handle);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index a98f0b00b2c70055f697ed4f15cb14708384b62f..588a45ea43f90c4d9b3d04fea305d2c562ae1d72 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -121,6 +121,7 @@ cc_library(
deps = [
":array_grad",
":data_flow_grad",
+ ":image_grad",
":math_grad",
":nn_grad",
],
@@ -331,6 +332,36 @@ tf_cc_test(
],
)
+cc_library(
+ name = "image_grad",
+ srcs = ["gradients/image_grad.cc"],
+ deps = [
+ ":cc_ops",
+ ":cc_ops_internal",
+ ":grad_op_registry",
+ ":gradients",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "gradients_image_grad_test",
+ srcs = ["gradients/image_grad_test.cc"],
+ deps = [
+ ":cc_ops",
+ ":client_session",
+ ":grad_op_registry",
+ ":grad_testutil",
+ ":gradient_checker",
+ ":image_grad",
+ ":testutil",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
cc_library(
name = "math_grad",
srcs = ["gradients/math_grad.cc"],
diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc
index ba056a8f3a84910aebf5079573cb64c19f41469d..0e61089a5950ee894ad5489317757cff8a85e966 100644
--- a/tensorflow/cc/client/client_session.cc
+++ b/tensorflow/cc/client/client_session.cc
@@ -127,4 +127,22 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
target_node_names, outputs, run_metadata);
}
+Status ClientSession::MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) {
+ TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
+ return impl()->session_->MakeCallable(callable_options, out_handle);
+}
+
+Status ClientSession::RunCallable(CallableHandle handle,
+ const std::vector& feed_tensors,
+ std::vector* fetch_tensors,
+ RunMetadata* run_metadata) {
+ return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
+ run_metadata);
+}
+
+Status ClientSession::ReleaseCallable(CallableHandle handle) {
+ return impl()->session_->ReleaseCallable(handle);
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h
index 5fb4109f7d15d5997f745acd913e60a02855fd73..7dd653eec4ec729b652cb779d06e820bfb437b3c 100644
--- a/tensorflow/cc/client/client_session.h
+++ b/tensorflow/cc/client/client_session.h
@@ -87,7 +87,33 @@ class ClientSession {
const std::vector& run_outputs,
std::vector* outputs, RunMetadata* run_metadata) const;
- // TODO(keveman): Add support for partial run.
+ /// \brief A handle to a subgraph, created with
+ /// `ClientSession::MakeCallable()`.
+ typedef int64 CallableHandle;
+
+ /// \brief Creates a `handle` for invoking the subgraph defined by
+ /// `callable_options`.
+ /// NOTE: This API is still experimental and may change.
+ Status MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle);
+
+ /// \brief Invokes the subgraph named by `handle` with the given options and
+ /// input tensors.
+ ///
+ /// The order of tensors in `feed_tensors` must match the order of names in
+ /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
+ /// match the order of names in `CallableOptions::fetch()` when this subgraph
+ /// was created.
+ /// NOTE: This API is still experimental and may change.
+ Status RunCallable(CallableHandle handle,
+ const std::vector& feed_tensors,
+ std::vector* fetch_tensors,
+ RunMetadata* run_metadata);
+
+ /// \brief Releases resources associated with the given `handle` in this
+ /// session.
+ /// NOTE: This API is still experimental and may change.
+ Status ReleaseCallable(CallableHandle handle);
private:
class Impl;
diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc
index ea5cf5a1f12be316cc6e0d0a02cd3caf4d177400..559ffea7e817526e7f1396cd0e8187d01364f23b 100644
--- a/tensorflow/cc/client/client_session_test.cc
+++ b/tensorflow/cc/client/client_session_test.cc
@@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) {
test::ExpectTensorEqual(outputs[0], test::AsTensor({-1, 2}, {2}));
}
+TEST(ClientSessionTest, Callable) {
+ Scope root = Scope::NewRootScope();
+ auto a = Placeholder(root, DT_INT32);
+ auto b = Placeholder(root, DT_INT32);
+ auto c = Add(root, a, b);
+ ClientSession session(root);
+ std::vector outputs;
+
+ CallableOptions options;
+ options.add_feed(a.node()->name());
+ options.add_feed(b.node()->name());
+ options.add_fetch(c.node()->name());
+ ClientSession::CallableHandle callable;
+ TF_CHECK_OK(session.MakeCallable(options, &callable));
+ TF_EXPECT_OK(session.RunCallable(
+ callable, {test::AsTensor({1}, {}), test::AsTensor({41}, {})},
+ &outputs, nullptr));
+ test::ExpectTensorEqual(outputs[0], test::AsTensor({42}, {}));
+ TF_EXPECT_OK(session.ReleaseCallable(callable));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc
index de2645cb440bda1f35e764af9197ca97bb760c08..e9f9c59e3aa0e8a9dc5d5e658540e9da73adaca5 100644
--- a/tensorflow/cc/framework/gradient_checker.cc
+++ b/tensorflow/cc/framework/gradient_checker.cc
@@ -247,7 +247,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs,
auto y_pos_flat = y_pos[y_idx].flat();
auto y_neg_flat = y_neg[y_idx].flat();
const int64 y_size = y_shapes[y_idx].num_elements();
- const Y_T scale = Y_T{2 * delta};
+ const Y_T scale = 2 * delta;
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix();
for (int c = 0; c < y_size; ++c) {
SetJacobian(&jacobian, r * x_stride + unit_dimension,
@@ -351,7 +351,14 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
auto jac_n = jacobian_ns[i].matrix();
for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) {
for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) {
- *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c)));
+ auto cur_error = std::fabs(jac_t(r, c) - jac_n(r, c));
+ // Treat any NaN as max_error and immediately return.
+ // (Note that std::max may ignore NaN arguments.)
+ if (std::isnan(cur_error)) {
+ *max_error = cur_error;
+ return Status::OK();
+ }
+ *max_error = std::max(*max_error, cur_error);
}
}
}
@@ -409,6 +416,7 @@ Status ComputeGradientError(const Scope& scope, const Output& x,
const Output& y, const TensorShape& y_shape, JAC_T* max_error);
INSTANTIATE_GRAD_ERR_TYPE(float, float, float);
+INSTANTIATE_GRAD_ERR_TYPE(double, float, double);
INSTANTIATE_GRAD_ERR_TYPE(double, double, double);
INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float);
INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float);
diff --git a/tensorflow/cc/framework/gradient_checker_test.cc b/tensorflow/cc/framework/gradient_checker_test.cc
index d4f0a7f5ab3716be41e22c02a21aca028f76fb88..8dd762c282eff287bddd49ea6f38b2b8060949b0 100644
--- a/tensorflow/cc/framework/gradient_checker_test.cc
+++ b/tensorflow/cc/framework/gradient_checker_test.cc
@@ -28,12 +28,14 @@ namespace {
using ops::Complex;
using ops::Const;
+using ops::Div;
using ops::MatMul;
using ops::Placeholder;
using ops::Real;
using ops::Split;
using ops::Square;
using ops::Stack;
+using ops::Sub;
using ops::Unstack;
TEST(GradientCheckerTest, BasicFloat) {
@@ -104,6 +106,20 @@ TEST(GradientCheckerTest, Complex64ToFloat) {
EXPECT_LT(max_error, 1e-4);
}
+// When calculating gradients that are undefined, test we get NaN
+// as the computed error rather than 0.
+TEST(GradientCheckerTest, BasicNan) {
+ Scope scope = Scope::NewRootScope();
+ TensorShape shape({2, 4, 3});
+ auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
+ // y = x/(x-x) should always return NaN
+ auto y = Div(scope, x, Sub(scope, x, x));
+ float max_error;
+ TF_ASSERT_OK((ComputeGradientError(
+ scope, {x}, {shape}, {y}, {shape}, &max_error)));
+ EXPECT_TRUE(std::isnan(max_error));
+}
+
TEST(GradientCheckerTest, MatMulGrad) {
Scope scope = Scope::NewRootScope();
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index b353accddcb6db9a07c112de03ead2f02c4ee6a6..e9173227aadbf86eab666e6c17bacacb92888572 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Split", SplitGrad);
+Status FillGrad(const Scope& scope, const Operation& op,
+ const std::vector