diff --git a/.gitignore b/.gitignore
index 5afe375f46f07b3b557ae23f75740b337517d3bd..1ef4c297ee4f369775c13b32a46a55887de719e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ __pycache__
*.swp
.vscode/
cmake_build/
+tensorflow/contrib/cmake/_build/
.idea/**
/build/
[Bb]uild/
@@ -30,6 +31,7 @@ Podfile.lock
xcuserdata/**
/api_init_files_list.txt
/estimator_api_init_files_list.txt
+*.whl
# Android
.gradle
diff --git a/CODEOWNERS b/CODEOWNERS
index b9f0313cc6d59d3fbdcd014e1a528126d863075a..113eaf798f7a0abf1a9ad3fed6308f234f8efe75 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,53 +1,62 @@
-# NOTE: Disabled temporarily because it's too noisy on pushes.
# Where component owners are known, add them here.
-# /tensorflow/core/platform/windows/ @mrry
-# /tensorflow/java/ @asimshankar
-# /tensorflow/tensorboard/ @jart @dandelionmane
-# /tensorflow/tools/docs/ @markdaoust
+/tenosrflow/core/debug @caisq
+/tensorflow/core/platform/windows/ @mrry
+/tensorflow/go @asimshankar
+/tensorflow/java/ @asimshankar
+/tensorflow/python/debug @caisq
+/tensorflow/python/tools/api/generator/ @annarev
+/tensorflow/tensorboard/ @jart
+/tensorflow/tools/docs/ @markdaoust
# contrib
-# NEED OWNER: /tensorflow/contrib/avro/
-# /tensorflow/contrib/batching/ @alextp @chrisolston
-# /tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon
-# /tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva
-# /tensorflow/contrib/cmake/ @mrry @benoitsteiner
-# /tensorflow/contrib/copy_graph/ @tucker @poxvoculi
-# /tensorflow/contrib/crf/ @kentonl
-# /tensorflow/contrib/data/ @mrry
-# /tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi
-# /tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo
-# /tensorflow/contrib/ffmpeg/ @fredbertsch
-# NEED OWNER: /tensorflow/contrib/framework/
-# /tensorflow/contrib/graph_editor/ @purpledog
+# NEED OWNER: /tensorflow/contrib/all_reduce
+/tensorflow/contrib/batching/ @alextp @chrisolston
+/tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon
+/tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva
+/tensorflow/contrib/checkpoint/ @allenlavoie
+/tensorflow/contrib/contrib/cluster_resolver/ @frankchn
+/tensorflow/contrib/cmake/ @mrry
+/tensorflow/contrib/copy_graph/ @tucker @poxvoculi
+/tensorflow/contrib/crf/ @kentonl
+/tensorflow/contrib/data/ @mrry
+/tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn
+/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi
+/tensorflow/contrib/eager @alextp @asimshankar
+/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo
+/tensorflow/contrib/ffmpeg/ @fredbertsch
+/tensorflow/contrib/framework/ @ebrevdo
+/tensorflow/contrib/gan/ @joel-shor
+/tensorflow/contrib/graph_editor/ @purpledog
# NEED OWNER: /tensorflow/contrib/grid_rnn/
-# /tensorflow/contrib/hvx/ @satok16
-# /tensorflow/contrib/integrate/ @shoyer
-# /tensorflow/contrib/kernel_methods/ @petrosmol
-# /tensorflow/contrib/ios_examples/ @petewarden
-# /tensorflow/contrib/labeled_tensor/ @shoyer
-# /tensorflow/contrib/layers/ @fchollet @martinwicke
-# /tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp
-# /tensorflow/contrib/linalg/ @langmore
-# /tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis
-# /tensorflow/contrib/lookup/ @ysuematsu @andreasst
-# /tensorflow/contrib/losses/ @alextp @ispirmustafa
-# /tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg
-# /tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa
-# /tensorflow/contrib/nccl/ @cwhipkey @zheng-xq
-# /tensorflow/contrib/opt/ @strategist333
-# /tensorflow/contrib/pi_examples/ @maciekcc
-# /tensorflow/contrib/quantization/ @petewarden @cwhipkey @keveman
-# /tensorflow/contrib/rnn/ @ebrevdo
-# /tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh
-# /tensorflow/contrib/seq2seq/ @lukaszkaiser
-# /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh
-# /tensorflow/contrib/slim/ @sguada @thenbasilmanran
-# /tensorflow/contrib/stateless/ @girving
-# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank
-# /tensorflow/contrib/testing/ @dandelionmane
-# /tensorflow/contrib/timeseries/ @allenlavoie
-# /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu
-# /tensorflow/contrib/training/ @joel-shor @ebrevdo
-# /tensorflow/contrib/util/ @sherrym
+/tensorflow/contrib/hvx/ @satok16
+/tensorflow/contrib/integrate/ @shoyer
+/tensorflow/contrib/kernel_methods/ @petrosmol
+/tensorflow/contrib/ios_examples/ @petewarden
+/tensorflow/contrib/labeled_tensor/ @shoyer
+/tensorflow/contrib/layers/ @fchollet @martinwicke
+/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp
+/tensorflow/contrib/linalg/ @langmore
+/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis
+/tensorflow/contrib/lookup/ @ysuematsu @andreasst
+/tensorflow/contrib/losses/ @alextp @ispirmustafa
+/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg
+/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa
+/tensorflow/contrib/nccl/ @cwhipkey @zheng-xq
+/tensorflow/contrib/opt/ @strategist333 @alextp
+/tensorflow/contrib/pi_examples/ @maciekcc
+/tensorflow/contrib/quantization/ @petewarden
+/tensorflow/contrib/rnn/ @ebrevdo @scottzhu
+/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl
+/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang
+/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh
+/tensorflow/contrib/slim/ @sguada @thenbasilmanran
+/tensorflow/contrib/stateless/ @girving @alextp
+/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank
+/tensorflow/contrib/tensorrt/ @laigd
+# NEED OWNER: /tensorflow/contrib/testing/
+/tensorflow/contrib/timeseries/ @allenlavoie
+/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj
+/tensorflow/contrib/training/ @joel-shor @ebrevdo
+/tensorflow/contrib/util/ @sherrym
\ No newline at end of file
diff --git a/README.md b/README.md
index 16d354ca7b150814f11fd825d6a22c84cebc2a01..91f49f8e95cc25fc9bd052ccd13a3c1cae232740 100644
--- a/README.md
+++ b/README.md
@@ -100,16 +100,16 @@ The TensorFlow project strives to abide by generally accepted best practices in
| **IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA |
| **IBM ppc64le GPU** | [](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA |
| **Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) |
-| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)
[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)
[1.9.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) |
+| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)
[1.10.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)
[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl) |
## For more information
-* [Tensorflow Blog](https://medium.com/tensorflow)
+* [TensorFlow Blog](https://medium.com/tensorflow)
* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
-* [Tensorflow Twitter](https://twitter.com/tensorflow)
+* [TensorFlow Twitter](https://twitter.com/tensorflow)
* [TensorFlow Website](https://www.tensorflow.org)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index b8adf6c1279e72d0c2056368253aa0cb470216e5..173bbea596a4276559f5cd67824e5cc75313985c 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -1240,7 +1240,7 @@ void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
const char* value, size_t length) {
tensorflow::NameAttrList func_name;
- func_name.set_name(std::string(value, value + length));
+ func_name.set_name(string(value, value + length));
desc->node_builder.Attr(attr_name, func_name);
}
@@ -2065,7 +2065,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
for (int i = 0; i < size; ++i) {
TensorId id = results.missing_unused_input_map_keys[i];
- tf_results->missing_unused_key_names_data.push_back(std::string(id.first));
+ tf_results->missing_unused_key_names_data.emplace_back(id.first);
tf_results->missing_unused_key_names[i] =
tf_results->missing_unused_key_names_data.back().c_str();
tf_results->missing_unused_key_indexes[i] = id.second;
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index aa2a537f03be31ae45ff3d6f7815b449d661cf9c..03516c39dc970aa23967107d3a0446da94669465 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -259,8 +259,8 @@ TEST(CAPI, DeprecatedSession) {
TF_Run(session, run_options, nullptr, nullptr, 0, nullptr, nullptr, 0,
nullptr, 0, run_metadata, s);
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(std::string("Session was not created with a graph before Run()!"),
- std::string(TF_Message(s)));
+ EXPECT_EQ("Session was not created with a graph before Run()!",
+ string(TF_Message(s)));
TF_DeleteBuffer(run_metadata);
TF_DeleteBuffer(run_options);
@@ -1224,8 +1224,8 @@ class CApiColocationTest : public ::testing::Test {
TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_);
if (expected.empty()) {
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
- EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."),
- std::string(TF_Message(s_)));
+ EXPECT_EQ("Operation 'add' has no attr named '_class'.",
+ string(TF_Message(s_)));
return;
}
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
@@ -1369,16 +1369,16 @@ TEST(CAPI, SavedModel) {
input.flat()(i) = example.SerializeAsString();
}
- const tensorflow::string input_op_name =
- std::string(tensorflow::ParseTensorName(input_name).first);
+ const tensorflow::string input_op_name(
+ tensorflow::ParseTensorName(input_name).first);
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- const tensorflow::string output_op_name =
- std::string(tensorflow::ParseTensorName(output_name).first);
+ const tensorflow::string output_op_name(
+ tensorflow::ParseTensorName(output_name).first);
TF_Operation* output_op =
TF_GraphOperationByName(graph, output_op_name.c_str());
ASSERT_TRUE(output_op != nullptr);
diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc
index 74bc25a491ac01cb725d1c004197e48727c30230..d3311f0cd06f2b151c3567735eb41b5baf72e102 100644
--- a/tensorflow/c/checkpoint_reader.cc
+++ b/tensorflow/c/checkpoint_reader.cc
@@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() {
const auto& slice_proto = entry.slices(i);
CHECK(filtered_keys
.insert(EncodeTensorNameSlice(
- std::string(v2_reader_->key()) /* full var's name */,
+ string(v2_reader_->key()) /* full var's name */,
TensorSlice(slice_proto)))
.second);
}
@@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() {
new TensorSliceReader::VarToDataTypeMap);
v2_reader_->Seek(kHeaderEntryKey);
for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
- if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue;
+ if (filtered_keys.count(string(v2_reader_->key())) > 0) continue;
CHECK(entry.ParseFromArray(v2_reader_->value().data(),
v2_reader_->value().size()))
<< entry.InitializationErrorString();
- string key = std::string(v2_reader_->key());
+ string key(v2_reader_->key());
(*var_to_shape_map)[key] = TensorShape(entry.shape());
(*var_to_data_type_map)[key] = DataType(entry.dtype());
}
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
old mode 100644
new mode 100755
index dfb1c9a37644c726e1eabab775593596d5b556b9..1ccae3f138920b1908f18387ea87b11388115d37
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -244,8 +244,8 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
}
void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
- unsigned char async) {
- options->async = async;
+ unsigned char enable) {
+ options->async = enable;
}
void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
@@ -253,9 +253,9 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
}
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
- unsigned char async,
+ unsigned char enable,
TF_Status* status) {
- status->status = ctx->context.SetAsyncForThread(async);
+ status->status = ctx->context.SetAsyncForThread(enable);
}
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
old mode 100644
new mode 100755
index a0ebc6fa0a22ed61be91c2974352c2988fb4cd92..eec2750d6eb3bceed8da3ed44812ac2e8fd5c877
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -76,7 +76,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
// Sets the default execution mode (sync/async). Note that this can be
// overridden per thread using TFE_ContextSetAsyncForThread.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
- unsigned char async);
+ unsigned char enable);
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
@@ -114,7 +114,7 @@ TFE_ContextGetDevicePlacementPolicy(TFE_Context*);
// Overrides the execution mode (sync/async) for the current thread.
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
- unsigned char async,
+ unsigned char enable,
TF_Status* status);
// A tensorflow.ServerDef specifies remote workers (in addition to the current
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index c20ea95a15e3f53b9b26716ed7b624fa853017c9..a32d1b1eb50fc715084f5ee663a732770db1883c 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -466,7 +466,7 @@ string AvoidCPPKeywords(StringPiece name) {
if (IsCPPKeyword(name)) {
return strings::StrCat(name, "_");
}
- return std::string(name);
+ return string(name);
}
void InferArgAttributes(const OpDef::ArgDef& arg,
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 8c886f31711eb014fb9e9d600c9c78cf22073f71..7f6ac4cae78d8d6e118837fce9ae5270336cdc89 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -225,7 +225,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints(
for (const string& entry : node_constraints) {
StringPiece s(entry);
if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
- current_constraints.insert(std::string(s));
+ current_constraints.emplace(s);
}
}
} else {
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index 3830416159158cca8bfb8422c2959b49fa42406d..c6abe2f41b9b5ec2faee6f65b429ff606f8ac08e 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -148,7 +148,7 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir,
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata;
const StringPiece main_op_name = main_op_it->second.node_list().value(0);
- return RunOnce(run_options, inputs, {}, {main_op_name.ToString()},
+ return RunOnce(run_options, inputs, {}, {string(main_op_name)},
nullptr /* outputs */, &run_metadata, session);
}
return Status::OK();
@@ -182,12 +182,12 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
variables_path_tensor.scalar()() = variables_path;
std::vector> inputs = {
- {variable_filename_const_op_name.ToString(), variables_path_tensor}};
+ {string(variable_filename_const_op_name), variables_path_tensor}};
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata;
- return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()},
+ return RunOnce(run_options, inputs, {}, {string(restore_op_name)},
nullptr /* outputs */, &run_metadata, session);
}
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 2220d0786d3757abc378d1a3d0ddc704bba6a4f3..59b961cdd9dac8a1c305a3f5f520ca1b68148cca 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -32,7 +32,6 @@ cc_library(
deps = [
":embedded_protocol_buffers",
"//tensorflow/compiler/tf2xla",
- "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
@@ -56,6 +55,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -72,6 +72,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
"@llvm//:support", # fixdeps: keep
"@llvm//:x86_code_gen", # fixdeps: keep
],
@@ -100,6 +101,7 @@ cc_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -195,6 +197,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 44291d977f8e97bdcba8131363e65956cad60cb7..e77a8fecf09fa037726b0baf5d2f38aeae0ef155 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -20,9 +20,10 @@ limitations under the License.
#include
#include "absl/memory/memory.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
-#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@@ -142,7 +142,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
}
rewrites->push_back({"{{I}}", strings::StrCat(i)});
rewrites->push_back({"{{TYPE}}", type});
- rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")});
+ rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
rewrites->push_back({"{{INDICES}}", indices});
return Status::OK();
@@ -158,8 +158,9 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
// text-templating mechanism.
string RewriteWithName(const string& name, string code,
const std::vector>& rewrites) {
- str_util::ReplaceAllPairs(&code, rewrites);
- return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true);
+ absl::StrReplaceAll(rewrites, &code);
+ absl::StrReplaceAll({{"{{NAME}}", name}}, &code);
+ return code;
}
// Generate methods for args (inputs).
@@ -571,11 +572,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
- {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")},
+ {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}",
- str_util::Join(metadata_result.header_variable_decls, "\n")},
+ absl::StrJoin(metadata_result.header_variable_decls, "\n")},
{"{{ENTRY}}", compile_result.entry_point},
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
metadata_result.hlo_profile_printer_data_access_shim},
@@ -595,8 +596,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
{"{{BUFFER_INFOS_AS_STRING}}",
- str_util::Join(buffer_infos_as_strings, ",\n")}};
- str_util::ReplaceAllPairs(header, rewrites);
+ absl::StrJoin(buffer_infos_as_strings, ",\n")}};
+ absl::StrReplaceAll(rewrites, header);
return Status::OK();
}
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index 60d59ae996e8f7ec490c98aeab05182626e61976..e3a53edb7368c209bea16a9e34b1f452a8ff4bf8 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -18,13 +18,13 @@ limitations under the License.
#include
#include
+#include "absl/strings/match.h"
#include "llvm/Support/TargetSelect.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@@ -34,9 +34,9 @@ namespace {
using ::tensorflow::cpu_function_runtime::BufferInfo;
-void ExpectErrorContains(const Status& status, StringPiece str) {
+void ExpectErrorContains(const Status& status, absl::string_view str) {
EXPECT_NE(Status::OK(), status);
- EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), str))
<< "expected error: " << status.error_message() << " to contain: " << str;
}
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
index 8fb2fad31c680c5dbbd058a1b9a9265607224429..1401aae7586bfd40ec209b0ae591d6ab69d0a26b 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include
#include "absl/memory/memory.h"
+#include "absl/strings/str_replace.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/LLVMContext.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
-#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -65,14 +65,13 @@ static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name,
" return proto;\n"
" }()";
- str_util::ReplaceAllPairs(
- &code,
+ return absl::StrReplaceAll(
+ code,
{
{"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)},
{"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)},
{"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)},
});
- return code;
}
static StatusOr CodegenModule(llvm::TargetMachine* target_machine,
@@ -97,7 +96,7 @@ static StatusOr>
GetTargetMachineFromTriple(StringPiece target_triple) {
std::string error;
std::string normalized_triple =
- llvm::Triple::normalize(AsStringRef(target_triple));
+ llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple)));
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget(normalized_triple, error);
if (target == nullptr) {
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 0ecc3feeb6fef1dd691ab2785b3221075a79ba88..723e9bec8afcfbf7ceeeb59c63e4e12442fdb7ab 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -187,6 +187,9 @@ tf_library(
cpp_class = "MatMulAndAddCompWithProfiling",
enable_xla_hlo_profiling = True,
graph = "test_graph_tfmatmulandadd.pb",
+ tags = [
+ "manual",
+ ],
)
tf_library(
@@ -226,5 +229,6 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 0c0c676ece78565e03578d3e33633c7e23b77669..dd2b151098f2054571ac32b8b506cbc00659588a 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL
+#include "absl/strings/str_split.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -546,7 +546,7 @@ TEST(TFCompileTest, HloProfiling) {
VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
std::vector hlo_profile_lines =
- tensorflow::str_util::Split(hlo_profile_as_string, '\n');
+ absl::StrSplit(hlo_profile_as_string, '\n');
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index 839e1588b7be6c91cf30c87bbaf75402446bd169..f3c44e9dda8ce96a268420a7f4d0f22e50ddfe41 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include
#include
+#include "absl/strings/match.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
@@ -34,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -55,7 +56,7 @@ const char kUsageHeader[] =
"\n";
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
- if (str_util::EndsWith(fname, ".pbtxt")) {
+ if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
@@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) {
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
- std::cout << str_util::Join(nodes, ",");
+ std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 2466c218c82dbd504043dbfff70fb3ba88d38e3b..df81f3c23e38a2ec2cea827cd0adb123855e7714 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -310,6 +310,51 @@ tf_cc_test(
],
)
+cc_library(
+ name = "resource_operation_safety_analysis",
+ srcs = ["resource_operation_safety_analysis.cc"],
+ hdrs = ["resource_operation_safety_analysis.h"],
+ deps = [
+ "//tensorflow/compiler/jit/graphcycles",
+ "//tensorflow/compiler/tf2xla:resource_operation_table",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+tf_cc_test(
+ name = "resource_operation_safety_analysis_test",
+ srcs = ["resource_operation_safety_analysis_test.cc"],
+ deps = [
+ ":common",
+ ":resource_operation_safety_analysis",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:functional_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
+ "//tensorflow/cc:sendrecv_ops",
+ "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
cc_library(
name = "compilation_passes",
srcs = [
@@ -335,11 +380,10 @@ cc_library(
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
- "//tensorflow/compiler/jit/kernels:parallel_check_op",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
- "//tensorflow/compiler/jit/ops:parallel_check_op",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
@@ -351,6 +395,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
+ "@com_google_absl//absl/strings",
],
)
@@ -359,6 +404,7 @@ cc_library(
srcs = ["xla_cluster_util.cc"],
hdrs = ["xla_cluster_util.h"],
deps = [
+ ":resource_operation_safety_analysis",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
@@ -437,6 +483,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
"//tensorflow/compiler/tf2xla:xla_compiler",
@@ -448,6 +495,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
],
)
@@ -528,6 +576,9 @@ tf_cuda_cc_test(
":common",
":xla_cluster_util",
":xla_fusion_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
"//tensorflow/core:graph",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 1b1ce78ed2b79d0948b6fc951f82a2cebe8009e5..56b034a30b7bddb023e54ead22c91a7a18095d2d 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -126,7 +126,8 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
const DataTypeVector& arg_types = (*fbody)->arg_types;
std::vector const_args(arg_types.size());
// If we can't analyze the const args. Bail out.
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+ *((*fbody)->graph), &const_args, /*compile_time_const_nodes=*/nullptr));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
@@ -208,8 +209,13 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
// device memory.
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
- // in device memory
+ // in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
+ for (int i = 0; i < fbody->ret_types.size(); ++i) {
+ if (fbody->ret_types[i] == DT_RESOURCE) {
+ output_memory_types[i] = HOST_MEMORY;
+ }
+ }
// Create the kernel.
NameAttrList function;
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 0ca0f949dcd13992ccd9504d75ca65d2aff72a19..fe28502f69d34e7c075bdf85afd2473024b4081d 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
@@ -153,7 +154,7 @@ class AndPredicate : public Predicate {
std::back_inserter(operands_str),
[](Predicate* pred) { return pred->ToString(); });
- return strings::StrCat("(", str_util::Join(operands_str, " & "), ")");
+ return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
}
Kind kind() const override { return Kind::kAnd; }
@@ -182,7 +183,7 @@ class OrPredicate : public Predicate {
std::back_inserter(operands_str),
[](Predicate* pred) { return pred->ToString(); });
- return strings::StrCat("(", str_util::Join(operands_str, " | "), ")");
+ return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
}
Kind kind() const override { return Kind::kOr; }
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index cc9f1023985560be0bce5971931d2ec8e742b377..28a56044d5e3795fc3ecf5d1092491b87cb90f01 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index f150bf1819d407e1c6a279673a89de4307b5426b..2788102620546d8eab657c519f078c5b03e265cc 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
@@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@@ -2504,7 +2504,8 @@ Status EncapsulateSubgraphsPass::Run(
const int num_args = input_permutation->size();
std::vector const_args(num_args);
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+ **subgraph, &const_args, /*compile_time_const_nodes=*/nullptr));
DataTypeVector arg_types(num_args);
TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index c0543a00792235c5dd090e81930d8c219dc7f1a3..b3600fc48b9daa0e901e2b01cdc121aef0a1e8af 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
@@ -124,8 +124,8 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
std::unordered_set control_input_a;
std::unordered_set control_input_b;
for (int i = 0; i < a.input_size(); ++i) {
- if (str_util::StartsWith(a.input(i), "^")) {
- if (!str_util::StartsWith(b.input(i), "^")) {
+ if (absl::StartsWith(a.input(i), "^")) {
+ if (!absl::StartsWith(b.input(i), "^")) {
if (diff) {
*diff = strings::StrCat(
diff_preamble, " mismatch for node ", a.name(), " input ", i,
@@ -768,7 +768,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
Graph* graph = graph_ptr->get();
for (const Node* n : graph->nodes()) {
if (n->type_string() == "_Arg" &&
- str_util::StartsWith(n->name(), "const")) {
+ absl::StartsWith(n->name(), "const")) {
++guaranteed_consts;
EXPECT_TRUE(HasGuaranteeConstAttr(*n));
} else {
@@ -813,7 +813,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
Graph* graph = graph_ptr->get();
for (const Node* n : graph->nodes()) {
if (n->type_string() == "_Arg" &&
- str_util::StartsWith(n->name(), "const")) {
+ absl::StartsWith(n->name(), "const")) {
++guaranteed_consts;
EXPECT_TRUE(HasGuaranteeConstAttr(*n));
} else {
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 8f78c110cb15f3cbc0344d102764241996b0d7de..253a5d254792a19d98b75310ea6848f42597c0c7 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -29,16 +29,3 @@ cc_library(
],
alwayslink = 1,
)
-
-cc_library(
- name = "parallel_check_op",
- srcs = ["parallel_check_op.cc"],
- visibility = ["//tensorflow/compiler/jit:friends"],
- deps = [
- "//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- ],
- alwayslink = 1,
-)
diff --git a/tensorflow/compiler/jit/kernels/parallel_check_op.cc b/tensorflow/compiler/jit/kernels/parallel_check_op.cc
deleted file mode 100644
index bd4eefbc0bb960f8ddc1d238057e73a29a098f26..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/jit/kernels/parallel_check_op.cc
+++ /dev/null
@@ -1,144 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
-#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-
-namespace tensorflow {
-namespace {
-
-// Inputs 2*N tensors, outputs the first N inputs.
-// Logs errors if input tensor i and i + N are not (near) identical
-// in any position.
-class ParallelCheckOp : public OpKernel {
- public:
- explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- template
- int CompareTensors(DataType dtype, const char* v0, const char* v1,
- int64 num_elts, int input_idx) {
- int failed = 0;
- const T* p0 = reinterpret_cast(v0);
- const T* p1 = reinterpret_cast(v1);
- double rtol;
- legacy_flags::ParallelCheckOpFlags* flags =
- legacy_flags::GetParallelCheckOpFlags();
- if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(),
- &rtol)) {
- LOG(ERROR) << "can't convert parallel_check_rtol "
- << flags->parallel_check_rtol << " to double";
- }
- double atol;
- if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(),
- &atol)) {
- LOG(ERROR) << "can't convert parallel_check_atol "
- << flags->parallel_check_atol << " to double";
- }
- for (int i = 0; i < num_elts; ++i) {
- bool ok = (p0[i] == p1[i]);
- VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i];
- if (!ok) {
- if (std::is_same::value || std::is_same::value) {
- float tolerance =
- std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i])));
- T diff = p0[i] - p1[i];
- if (diff < 0) diff = 0 - diff;
- ok = (diff <= tolerance);
- }
- if (ok) continue;
- LOG(ERROR) << "Op " << name() << " fails equality at output "
- << input_idx << " type " << DataTypeString(dtype)
- << " element " << i << ": std_val=" << p0[i]
- << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]);
- if (++failed > 10) break;
- }
- }
- return failed;
- }
-
- void Compute(OpKernelContext* ctx) override {
- VLOG(1) << "Compute " << name();
- const int num_pairs = ctx->num_inputs() / 2;
- for (int i = 0; i < num_pairs; ++i) {
- CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs));
- Tensor t0 = ctx->input(i);
- Tensor t1 = ctx->input(i + num_pairs);
- int64 num_elts = t0.NumElements();
- CHECK_EQ(num_elts, t1.NumElements());
-
- // Compare inputs elementwise for near-exact equality.
- const char* v0 = t0.tensor_data().data();
- const char* v1 = t1.tensor_data().data();
- int failed = 0;
- switch (ctx->input_dtype(i)) {
- case DT_INT32:
- failed =
- CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_INT64:
- failed =
- CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_FLOAT:
- failed =
- CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_DOUBLE:
- failed =
- CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_BOOL:
- failed =
- CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- default:
- LOG(FATAL) << "unimpl: " << ctx->input_dtype(i);
- }
- if (failed > 0) {
- LOG(ERROR) << "check failed for " << name() << " output " << i
- << " num_elts: " << num_elts;
- legacy_flags::ParallelCheckOpFlags* flags =
- legacy_flags::GetParallelCheckOpFlags();
- if (flags->parallel_check_failfast) {
- LOG(QFATAL) << "failfast on first parallel-check failure";
- }
- } else {
- VLOG(1) << "check passed for " << name() << " output " << i
- << " num_elts: " << num_elts;
- }
-
- // Propagate the std value.
- if (IsRefType(ctx->input_dtype(i))) {
- ctx->forward_ref_input_to_ref_output(i, i);
- } else {
- ctx->set_output(i, ctx->input(i));
- }
- }
- }
-
- TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp);
-};
-
-REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU),
- ParallelCheckOp);
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index ddb27a38ae3b6749a82f86ba8be88ec68e733006..fde4135bf7f5f7bdede170d47fb2a76d1d6b3ae9 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -187,7 +187,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(
ctx, cache->Compile(options, function_, constant_args, variables, ctx,
- &kernel, &executable, &compile_options));
+ &kernel, &executable, compile_options));
VLOG(1) << "Executing XLA Computation...";
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 11bd5eec238c0e542814f22bc7a33a90abd0ec28..4e4abade3278089a1c7f8fdee46a34b8ce503651 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -27,7 +27,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -40,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
@@ -74,18 +77,40 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
}
+bool HasResourceOutput(const Node& node) {
+ return std::find(node.output_types().begin(), node.output_types().end(),
+ DT_RESOURCE) != node.output_types().end();
+}
+
+bool HasResourceInput(const Node& node) {
+ return std::find(node.input_types().begin(), node.input_types().end(),
+ DT_RESOURCE) != node.input_types().end();
+}
+
+// Returns true if `node` is a resource operation recognized by tf2xla that
+// operates on something other than resource variables.
+bool IsNonResourceVarResourceOp(const Node& node) {
+ // TODO(b/112837194): We can't cluster these because we only support
+ // snapshotting resource variables (and we can't e.g. snapshot stacks). This
+ // limitation may be fixable with some work.
+ const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string());
+ return op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
+}
+
// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 10;
bool IsCompilableCall(const NodeDef& call_def,
- const DeviceType& jit_device_type, int depth,
+ const DeviceType& jit_device_type,
+ bool allow_resource_ops, int depth,
FunctionLibraryRuntime* lib_runtime);
// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool IsCompilableWhile(const Node& while_node,
- const DeviceType& jit_device_type, int depth,
+ const DeviceType& jit_device_type,
+ bool allow_resource_ops, int depth,
FunctionLibraryRuntime* lib_runtime) {
const NameAttrList* name_attr;
NodeDef call;
@@ -100,7 +125,8 @@ bool IsCompilableWhile(const Node& while_node,
call.set_name("while_cond");
call.set_op(cond_func);
*call.mutable_attr() = name_attr->attr();
- if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
+ if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1,
+ lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop condition: " << cond_func;
return false;
@@ -115,7 +141,8 @@ bool IsCompilableWhile(const Node& while_node,
call.set_name("while_body");
call.set_op(body_func);
*call.mutable_attr() = name_attr->attr();
- if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
+ if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1,
+ lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop body: " << body_func;
return false;
@@ -127,7 +154,8 @@ bool IsCompilableWhile(const Node& while_node,
// Every operator in the function must be compilable for a function to be
// compilable.
bool IsCompilableCall(const NodeDef& call_def,
- const DeviceType& jit_device_type, int depth,
+ const DeviceType& jit_device_type,
+ bool allow_resource_ops, int depth,
FunctionLibraryRuntime* lib_runtime) {
if (depth > kMaxRecursionDepth) {
VLOG(2) << "Rejecting " << call_def.op()
@@ -167,12 +195,17 @@ bool IsCompilableCall(const NodeDef& call_def,
if (node->type_string() == "_Arg" || node->type_string() == "_Retval")
continue;
if (node->type_string() == "While") {
- // Handle functional While loop (not in open source build).
- return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime);
+ // Handle functional While loop.
+ return IsCompilableWhile(*node, jit_device_type, allow_resource_ops,
+ depth + 1, lib_runtime);
+ }
+ if (!allow_resource_ops &&
+ (HasResourceInput(*node) || HasResourceOutput(*node))) {
+ return false;
}
if (!HasXLAKernel(*node, jit_device_type) &&
- !IsCompilableCall(node->def(), jit_device_type, depth + 1,
- lib_runtime)) {
+ !IsCompilableCall(node->def(), jit_device_type, allow_resource_ops,
+ depth + 1, lib_runtime)) {
VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op "
<< node->name() << ": " << node->def().ShortDebugString();
return false;
@@ -343,6 +376,10 @@ Status FindCompilationCandidates(
flib_def, opts));
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+ std::vector compile_time_const_nodes(graph.num_node_ids(), false);
+ TF_RETURN_IF_ERROR(
+ BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr,
+ &compile_time_const_nodes));
int64& fuel =
legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
@@ -386,19 +423,46 @@ Status FindCompilationCandidates(
XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration));
DeviceType jit_device_type(registration->compilation_device_name);
if (!HasXLAKernel(*node, jit_device_type) &&
- !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
+ !IsCompilableCall(node->def(), jit_device_type,
+ registration->compile_resource_ops, 0, lib_runtime)) {
VLOG(2) << "Rejecting " << node->name() << ": unsupported op "
<< node->type_string();
continue;
}
if (!registration->compile_resource_ops &&
- HasResourceInputOrOutput(*node)) {
- VLOG(2) << "Rejecting: " << node->name() << ": resource input/output "
+ (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
+ // We don't have a way of returning values of type DT_RESOURCE from XLA
+ // computations so we avoid auto-clustering nodes producing DT_RESOURCE.
+ // XlaLaunchOp also cannot snapshot resources that are not resource
+ // variables so we avoid clustering resource operations that operate on
+ // non-resource variables.
+ VLOG(2) << "Rejecting: " << node->name() << ": resource output "
<< node->type_string();
continue;
}
+ if (compile_time_const_nodes[node->id()] &&
+ !registration->requires_compilation) {
+ const OpDef* op_def;
+ TF_RETURN_IF_ERROR(
+ OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def));
+ if (op_def->is_stateful()) {
+ // We need to be able to constant fold the nodes in
+ // compile_time_const_nodes given constant inputs (required by XLA) and
+ // therefore can't auto-cluster stateful ops since these can never be
+ // constant folded.
+ VLOG(2) << "Rejecting " << node->name()
+ << ": must-be-constant stateful op";
+ continue;
+ }
+ }
+ // We don't auto-cluster functional control flow nodes containing resource
+ // operations because safety checks are trickier in this case.
+ // registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not
+ // for CPU/GPU.
if (node->type_string() == "While" &&
- !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) {
+ !IsCompilableWhile(*node, jit_device_type,
+ registration->compile_resource_ops, 0,
+ lib_runtime)) {
continue;
}
// _Arg nodes in a top-level function represent feeds.
@@ -457,7 +521,11 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
®istration));
DeviceType jit_device_type(registration->compilation_device_name);
- return IsCompilableCall(ndef, jit_device_type, 0, flr);
+
+ // We can always *compile* resource operations, even if we are sometimes
+ // unable to auto-cluster them.
+ const bool compile_resource_ops = true;
+ return IsCompilableCall(ndef, jit_device_type, compile_resource_ops, 0, flr);
}
Status MarkForCompilationPass::Run(
@@ -600,6 +668,82 @@ static void VLogClusteringSummary(const Graph& g) {
VLOG(3) << " " << pair.first << ": " << pair.second << " instances";
}
}
+
+ struct EdgeInfo {
+ StringPiece node_name;
+ absl::optional cluster_name;
+
+ StringPiece GetClusterName() const {
+ return cluster_name ? *cluster_name : "[none]";
+ }
+
+ std::pair> AsPair() const {
+ return {node_name, cluster_name};
+ }
+
+ bool operator<(const EdgeInfo& other) const {
+ return AsPair() < other.AsPair();
+ }
+ };
+
+ using EdgeInfoMap = std::map>;
+
+ EdgeInfoMap incoming_edge_infos;
+ EdgeInfoMap outgoing_edge_infos;
+
+ std::set cluster_names_to_print;
+
+ for (const Edge* e : g.edges()) {
+ const Node* from = e->src();
+ absl::optional from_cluster_name = GetXlaClusterForNode(*from);
+
+ const Node* to = e->dst();
+ absl::optional to_cluster_name = GetXlaClusterForNode(*to);
+
+ if (to_cluster_name == from_cluster_name) {
+ continue;
+ }
+
+ if (to_cluster_name) {
+ incoming_edge_infos[*to_cluster_name]
+ [EdgeInfo{from->name(), from_cluster_name}]++;
+ cluster_names_to_print.insert(*to_cluster_name);
+ }
+
+ if (from_cluster_name) {
+ outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++;
+ cluster_names_to_print.insert(*from_cluster_name);
+ }
+ }
+
+ VLOG(2) << "*** Inter-Cluster edges:";
+ if (cluster_names_to_print.empty()) {
+ VLOG(2) << " [none]";
+ }
+
+ auto print_edge_info_set_for_cluster = [&](StringPiece cluster_name,
+ const EdgeInfoMap& edge_info_map,
+ StringPiece desc) {
+ auto it = edge_info_map.find(cluster_name);
+ if (it != edge_info_map.end()) {
+ VLOG(2) << " " << it->second.size() << " " << desc << " edges";
+ for (const auto& edge_info_count_pair : it->second) {
+ VLOG(2) << " " << edge_info_count_pair.first.GetClusterName() << " "
+ << edge_info_count_pair.first.node_name << " # "
+ << edge_info_count_pair.second;
+ }
+ } else {
+ VLOG(2) << " No " << desc << " edges.";
+ }
+ };
+
+ for (StringPiece cluster_name : cluster_names_to_print) {
+ VLOG(2) << " ** Cluster " << cluster_name;
+ print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
+ "incoming");
+ print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos,
+ "outgoing");
+ }
}
// Is 'node' an operator that consumes only the shape of its input, not the
@@ -609,6 +753,43 @@ static bool IsShapeConsumerOp(const Node& node) {
node.type_string() == "Size";
}
+static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) {
+ // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
+ // ignore it during resource operation safety analysis. We need this hack
+ // because of two reasons:
+ //
+ // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
+ // 2. We don't support live-out values of type DT_RESOURCE and live-in values
+ // of type DT_RESOURCE that are not resource variables.
+ //
+ // Together these imply we cannot let resource variable safety analysis
+ // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
+ // clusters: both of them will have to be clustered because of (1) and we
+ // won't be able to keep the edge between the two as neither the input to the
+ // second XLA cluster nor the output from the first XLA cluster are supported
+ // because of (2).
+ //
+ // TODO(b/113100872): This can be fixed if the TensorFlow representation for
+ // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
+ // (2) would no longer hold.
+
+ if (n.assigned_device_name().empty()) {
+ *ignore = false;
+ return Status::OK();
+ }
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceToDeviceType(n.assigned_device_name(), &device_type));
+
+ const XlaOpRegistry::DeviceRegistration* registration;
+ if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) {
+ *ignore = true;
+ } else {
+ *ignore = registration->compile_resource_ops;
+ }
+ return Status::OK();
+}
+
// Sequence number generator to ensure clusters have unique names.
static std::atomic cluster_sequence_num;
@@ -637,6 +818,8 @@ Status MarkForCompilationPass::RunImpl(
GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles));
+ TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
+ graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles));
// Each compilation candidate belongs to a cluster. The cluster's
// representative
@@ -675,7 +858,7 @@ Status MarkForCompilationPass::RunImpl(
string to_scope;
for (int to : cycles.Successors(from)) {
if (to >= graph->num_node_ids()) {
- // Node is a "frame" node that is present only in the cycle detection
+ // Node is a fictitious node that is present only in the cycle detection
// graph. No clustering is possible.
continue;
}
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 9d7ac0d609eea370b8100e1eb53b0b0b3d9f2382..807ab51fd3c133b95915ea88e0bf99dbb8661452 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -15,10 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
@@ -26,11 +28,11 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -48,9 +50,35 @@ std::unordered_map GetClusters(const Graph& graph) {
ids[node->name()] = cluster;
}
}
+
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Clusters:";
+ for (const auto& p : ids) {
+ VLOG(2) << " " << p.first << " -> " << p.second;
+ }
+ }
return ids;
}
+gtl::FlatMap> GetClusterSets(
+ const Graph& g, std::vector* cluster_names = nullptr) {
+ CHECK(cluster_names == nullptr || cluster_names->empty());
+ gtl::FlatMap> cluster_sets;
+ for (const auto& p : GetClusters(g)) {
+ cluster_sets[p.second].push_back(p.first);
+ }
+ for (auto& p : cluster_sets) {
+ if (cluster_names != nullptr) {
+ cluster_names->push_back(p.first);
+ }
+ std::sort(p.second.begin(), p.second.end());
+ }
+ if (cluster_names != nullptr) {
+ std::sort(cluster_names->begin(), cluster_names->end());
+ }
+ return cluster_sets;
+}
+
TEST(XlaCompilationTest, Chains) {
std::unique_ptr graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
@@ -501,38 +529,104 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
EXPECT_EQ(clusters["B"], clusters["C"]);
}
-REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float");
-REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource");
-
namespace {
+Node* MakeRead(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output read =
+ ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
+ return read.node();
+}
-class DummyOp : public XlaOpKernel {
- using XlaOpKernel::XlaOpKernel;
- void Compile(XlaOpKernelContext* ctx) override {}
-};
-
-REGISTER_XLA_OP(Name("ResourceInput"), DummyOp);
-REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp);
+Node* MakeWrite(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write =
+ ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
+ ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id),
+ var_handle, value_to_write);
+ return assign_op.operation.node();
+}
+Node* MakeNeutral(const Scope& scope, const string& id) {
+ return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
+}
} // namespace
-TEST(XlaCompilationTest, Resources) {
+TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(read, write);
+
+ FixupSourceAndSinkEdges(root.graph());
std::unique_ptr graph(new Graph(OpRegistry::Global()));
- GraphDef graphdef;
- {
- GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
- Node* a =
- ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
- Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
- // We should not form clusters with resource ops by default.
- Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C"));
- Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D"));
- ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
- TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
- }
+ TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
- auto clusters = GetClusters(*graph);
- EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
+ gtl::FlatMap> cluster_sets =
+ GetClusterSets(*graph);
+ ASSERT_EQ(cluster_sets.size(), 1);
+ std::vector expected_clustered_nodes = {"AssignmentW", "ReadR",
+ "ValueToAssignW"};
+ ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
+}
+
+TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, read);
+
+ FixupSourceAndSinkEdges(root.graph());
+ std::unique_ptr graph(new Graph(OpRegistry::Global()));
+ TF_EXPECT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+ gtl::FlatMap> cluster_sets =
+ GetClusterSets(*graph);
+ ASSERT_EQ(cluster_sets.size(), 1);
+ std::vector expected_clustered_nodes = {"AssignmentW",
+ "ValueToAssignW"};
+ ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
+}
+
+TEST(XlaCompilationTest, ChainOfOps) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* neutral_0 = MakeNeutral(root, "N0");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral_1 = MakeNeutral(root, "N1");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral_0);
+ root.graph()->AddControlEdge(neutral_0, read_0);
+ root.graph()->AddControlEdge(read_0, write_1);
+ root.graph()->AddControlEdge(write_1, neutral_1);
+ root.graph()->AddControlEdge(neutral_1, read_1);
+
+ FixupSourceAndSinkEdges(root.graph());
+ std::unique_ptr graph(new Graph(OpRegistry::Global()));
+ TF_EXPECT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::vector cluster_names;
+ gtl::FlatMap> cluster_sets =
+ GetClusterSets(*graph, &cluster_names);
+
+ ASSERT_EQ(cluster_sets.size(), 2);
+
+ std::vector expected_clustered_nodes_a = {"AssignmentW0", "ConstN0",
+ "ValueToAssignW0"};
+ ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
+
+ std::vector expected_clustered_nodes_b = {
+ "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
+ ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b);
}
TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
@@ -562,11 +656,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
EXPECT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(status.ToString(),
- "Edge from c to a would create a cycle.\n"
- "+-> a\n"
- "| b\n"
- "+-- c\n"));
+ EXPECT_TRUE(absl::StrContains(status.ToString(),
+ "Edge from c to a would create a cycle.\n"
+ "+-> a\n"
+ "| b\n"
+ "+-- c\n"));
}
TEST(XlaCompilationTest, Retval) {
@@ -731,5 +825,27 @@ TEST(XlaCompilationTest, ClusterControlTrigger) {
EXPECT_EQ(clusters, expected_clusters);
}
+TEST(XlaCompilationTest, RandomShape) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1});
+ Output shape =
+ ops::RandomUniformInt(root.WithOpName("shape"), shape_shape,
+ ops::Const(root.WithOpName("minval"), 1),
+ ops::Const(root.WithOpName("maxval"), 20));
+ Output reshape_input =
+ ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({500, 500})));
+ Output reshape =
+ ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);
+
+ std::unique_ptr graph(new Graph(OpRegistry::Global()));
+
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map clusters = GetClusters(*graph);
+ EXPECT_EQ(clusters["shape"], "");
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD
index c9e46bc1475aed0e35a48765ad70eef4362e8281..13804c6a0575b921839f99ef7d142e0871693b5a 100644
--- a/tensorflow/compiler/jit/ops/BUILD
+++ b/tensorflow/compiler/jit/ops/BUILD
@@ -10,10 +10,3 @@ cc_library(
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)
-
-cc_library(
- name = "parallel_check_op",
- srcs = ["parallel_check_op.cc"],
- deps = ["//tensorflow/core:framework"],
- alwayslink = 1,
-)
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index 08a956e4c6478ff76a0fe8f1f60d94824daf535c..f61a955c222dd7ce11a177cd54bb8851a5400496 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1ba4a5ef7399111e512da8c4966f5899ed828b17
--- /dev/null
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -0,0 +1,336 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// ALGORITHM OVERVIEW
+// ==================
+//
+// An XLA cluster hoists all resource reads to be beginning of the cluster
+// execution and all the resource writes to the end. This means it cannot
+// enforce arbitrary ordering dependencies (via control or data edges) between
+// resource operations. Since all resource reads happen before all resource
+// writes, edges constraining resource reads to happen before resource writes
+// are fine, but all other kinds of edges are problematic. This analysis
+// computes the set of pairs of resource operations that cannot be put in the
+// same cluster because XLA cannot respect the dependencies between them in the
+// TensorFlow program.
+//
+// TODO(b/112856632): We can, in theory, support Read->Read and Write->Write
+// dependencies.
+//
+// Specifically the result computed by this analysis contains the edge {W, R}
+// iff all of these hold true:
+//
+// - In the graph (g - {edges from NextIteration to Merge}) there is a path
+// from W to R.
+// - IsEdgeSafe(W, R) == False [defined below]
+// - W != R (note: some resource operations both read from and write to
+// resource variables).
+//
+// The result is incorrect around loops because we ignore edges from
+// NextIteration to Merge, but that should be fine because we don't cluster
+// these edges. For instance, in:
+//
+// Init -----> Merge <-------+
+// | |
+// v |
+// Read |
+// | |
+// v |
+// Write |
+// | |
+// v |
+// NextIteration --+
+//
+// we won't put (Read, Write) in the returned set. This is fine if
+// auto-clustering can only cluster the Read->Write edge, but it is a problem if
+// it clusters the Write->NextIteration->Merge->Read edges instead. The same
+// problem is present for the functional version of the loop above. We rely on
+// auto-clustering to not cluster control flow edges like NextIteration->Merge.
+// This is enough to avoid the explicit-control-flow problem shown above. One
+// way to think about this is that we only care about cases where two nodes, A
+// and B, would normally have been put in the same cluster but cannot legally be
+// in the same cluster because of resourcevar-dependencies. If A and B would
+// normally have been put in the same cluster then all paths between A and B
+// would have to be clusterable (otherwise we'd have introduced a cycle). Ergo
+// there could not have been a NextIteration->Merge edge between A and B since
+// we don't cluster these edges.
+//
+// We also rely on auto-clustering to not cluster functional control flow nodes
+// that contain resource operations.
+//
+// IMPLEMENTATION
+// --------------
+//
+// We traverse the graph minus backedges in reverse post order, mapping each
+// node to the set of resource operation reaching that node. Since we visit
+// producers before consumers, we can construct the set of reaching operations
+// by taking the union of the operations reaching the input nodes. These
+// "reaching resource operations" can then be used to create the pairs of
+// incompatible nodes using `IsEdgeSafe`.
+
+#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace {
+// Returns true if `n` may call a function.
+Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def,
+ bool* out_result) {
+ if (flib_def->Contains(n.type_string())) {
+ *out_result = true;
+ } else {
+ *out_result =
+ std::any_of(n.def().attr().begin(), n.def().attr().end(),
+ [](const std::pair& name_attr_pair) {
+ return name_attr_pair.second.has_func();
+ });
+ }
+
+ return Status::OK();
+}
+
+// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is
+// not a resource operation recognized by XLA then sets `out_resource_op_kind`
+// to nullopt.
+Status XlaResourceOpKindForNode(
+ const Node& n, const FunctionLibraryDefinition* flib_def,
+ const std::function& resource_ops_to_ignore,
+ absl::optional* out_resource_op_kind) {
+ bool should_ignore = false;
+ if (resource_ops_to_ignore) {
+ TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore));
+ }
+ if (should_ignore) {
+ *out_resource_op_kind = absl::nullopt;
+ return Status::OK();
+ }
+
+ const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());
+ if (op_info) {
+ *out_resource_op_kind = op_info->kind();
+ return Status::OK();
+ }
+
+ // We conservatively assume that functions will both read and write resource
+ // variables. In the future we may consider doing some form of
+ // inter-procedural analysis.
+ bool may_call_function;
+ TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function));
+ if (may_call_function) {
+ *out_resource_op_kind = XlaResourceOpKind::kReadWrite;
+ } else {
+ *out_resource_op_kind = absl::nullopt;
+ }
+
+ return Status::OK();
+}
+
+// Returns true if a control or data dependence from a TensorFlow operation of
+// resource op kind `from` to a TensorFlow operation of resource op kind `to`
+// can be represented by an XLA cluster and needs no special handling around
+// auto-jit.
+bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) {
+ // XLA clusters forces all reads to happen before all writes, which means the
+ // kinds of edges it can faithfully represent are: Read->Write, Read->Modify,
+ // Modify->Write, Read->Read, Write->Write.
+ //
+ // TODO(b/112856632): We can, in theory, support Read->Read and Write->Write
+ // dependencies.
+ return from == XlaResourceOpKind::kRead && to == XlaResourceOpKind::kWrite;
+}
+
+using ResourceOp = std::pair;
+
+string ResourceOpToString(const ResourceOp& resource_op) {
+ return strings::StrCat(
+ resource_op.first, ": ",
+ XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second));
+}
+
+// A copy-on-write set used to store the set of ResourceOps reaching a node in a
+// TensorFlow graph.
+//
+// TODO(sanjoy): It may be useful to pull this out into its own header at some
+// point.
+class ResourceOpSet {
+ private:
+ using Impl = gtl::FlatSet;
+
+ public:
+ ResourceOpSet() = default;
+
+ // Adds all ResourceOp s in `other` to this set.
+ void Add(const ResourceOpSet& other) {
+ CHECK(!frozen_);
+ if (other.impl_ == impl_) {
+ other.frozen_ = true;
+ return;
+ }
+
+ if (!impl_) {
+ other.frozen_ = true;
+ impl_ = other.impl_;
+ return;
+ }
+
+ for (ResourceOp resource_op : other) {
+ Add(resource_op);
+ }
+ }
+
+ void Add(const ResourceOp& resource_op) {
+ CHECK(!frozen_);
+ if (!IsCopy() && Contains(resource_op)) {
+ // We can avoid the copy if the item we want to insert already exists.
+ return;
+ }
+
+ EnsureIsCopied();
+ impl_->insert(resource_op);
+ }
+
+ Impl::const_iterator begin() const {
+ return impl_ ? impl_->begin() : GetEmptyImpl()->begin();
+ }
+
+ Impl::const_iterator end() const {
+ return impl_ ? impl_->end() : GetEmptyImpl()->end();
+ }
+
+ bool Contains(const ResourceOp& resource_op) const {
+ return impl_ != nullptr && impl_->count(resource_op);
+ }
+
+ private:
+ bool IsCopy() const { return storage_ != nullptr; }
+
+ void EnsureIsCopied() {
+ if (storage_ == nullptr) {
+ storage_ = absl::make_unique();
+ for (ResourceOp op : *this) {
+ storage_->insert(op);
+ }
+ impl_ = storage_.get();
+ }
+ }
+
+ static Impl* GetEmptyImpl() {
+ static Impl* empty_impl = new Impl;
+ return empty_impl;
+ }
+
+ Impl* impl_ = nullptr;
+ std::unique_ptr storage_;
+
+ // frozen_ is true if there is another set pointing to this set's impl_. We
+ // can no longer add elements to this set in that case since the sets pointing
+ // to this set expect the contents of this set to be stable.
+ mutable bool frozen_ = false;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ResourceOpSet);
+};
+
+string ResourceOpSetToString(const ResourceOpSet& resource_op_set) {
+ std::vector elements_debug_string;
+ std::transform(resource_op_set.begin(), resource_op_set.end(),
+ std::back_inserter(elements_debug_string), ResourceOpToString);
+ return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
+}
+
+string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) {
+ return strings::StrCat(
+ "[", n.name(), ": ", n.type_string(), "(",
+ XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]");
+}
+} // namespace
+
+Status ComputeIncompatibleResourceOperationPairs(
+ const Graph& g, const FunctionLibraryDefinition* flib_def,
+ const std::function& resource_ops_to_ignore,
+ std::vector>* result) {
+ CHECK(result->empty());
+
+ std::vector rpo;
+ GetReversePostOrder(g, &rpo, /*stable_comparator=*/NodeComparatorName(),
+ /*edge_filter=*/[](const Edge& edge) {
+ return !edge.src()->IsNextIteration();
+ });
+
+ auto resource_op_set_for_node =
+ absl::make_unique(g.num_node_ids());
+
+ const bool vlog = VLOG_IS_ON(2);
+
+ for (Node* n : rpo) {
+ absl::optional op_kind;
+ TF_RETURN_IF_ERROR(XlaResourceOpKindForNode(
+ *n, flib_def, resource_ops_to_ignore, &op_kind));
+
+ ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()];
+
+ // Merge the reaching resource operations for all the incoming edges to
+ // create the set of all possible resource ops reaching `n`.
+ for (const Edge* e : n->in_edges()) {
+ if (n->IsMerge() && e->src()->IsNextIteration()) {
+ // Ignore back-edges (see file comment).
+ continue;
+ }
+
+ const ResourceOpSet& incoming_op_set =
+ resource_op_set_for_node[e->src()->id()];
+ resource_op_set->Add(incoming_op_set);
+ }
+
+ // Add to the "incompatible resource ops" set if necessary.
+ if (op_kind) {
+ for (ResourceOp incoming_op : *resource_op_set) {
+ if (IsEdgeSafe(incoming_op.second, *op_kind)) {
+ continue;
+ }
+
+ if (vlog) {
+ VLOG(2) << "Unsafe edge: "
+ << NodeToString(*g.FindNodeId(incoming_op.first),
+ incoming_op.second)
+ << " -> " << NodeToString(*n, *op_kind);
+ }
+ result->push_back({incoming_op.first, n->id()});
+ }
+
+ resource_op_set->Add({n->id(), *op_kind});
+ }
+
+ if (vlog) {
+ VLOG(3) << n->name() << " -> " << ResourceOpSetToString(*resource_op_set);
+ }
+ }
+
+ std::sort(result->begin(), result->end());
+ CHECK(std::unique(result->begin(), result->end()) == result->end());
+
+ return Status::OK();
+}
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h
new file mode 100644
index 0000000000000000000000000000000000000000..ae8cfeecad9b9cd631db3e9865bb3c3ff28a2e48
--- /dev/null
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h
@@ -0,0 +1,73 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
+
+#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+// An XLA cluster hoists all resource reads to be beginning of the cluster
+// execution and all the resource writes to the end. This means it cannot
+// enforce arbitrary ordering dependencies (via control or data edges) between
+// resource operations. Since all resource reads happen before all resource
+// writes, edges constraining resource reads to happen before resource writes
+// are fine, but all other kinds of edges are problematic. This analysis
+// returns the set of pairs of resource operations that cannot be put in the
+// same cluster because XLA cannot respect the dependencies between them in the
+// TensorFlow program.
+//
+// The restrictions are not transitive: it is fine to put A and C in the same
+// cluster even if the returned set contains (A,B) and (B,C).
+//
+// In other words, if these pairs are seen as edges in an undirected graph of
+// the nodes in `g` then auto-clustering is at least as constrained as the graph
+// coloring problem on this graph.
+//
+//
+// For instance if we auto-cluster all operations in this TensorFlow graph:
+//
+// ReadVariablepOp0 -> ReadVariableOp1
+// |
+// v
+// AssignVariableOp0 -> AssignVariableOp1
+//
+// we will lose the ReadVariablepOp0 -> ReadVariableOp1 and the
+// AssignVariableOp0 -> AssignVariableOp1 dependencies. I.e. it is possible for
+// XlaLaunchOp to issue ReadVariableOp1 before ReadVariablepOp0 since it reads
+// all the resource variables when the cluster starts executing without any
+// particular ordering between them; same holds for the AssignVariableOp0 ->
+// AssignVariableOp1 edge. The ReadVariableOp1 -> AssignVariableOp0 edge will
+// be respected by XlaLaunchOp though because all reads happen before all
+// writes.
+//
+//
+// NB! The result computed by this analysis assumes that we don't auto-cluster
+// back-edges (i.e. the edges from NextIteration to Merge).
+//
+// NB! The result computed by this analysis assumes that we don't auto-cluster
+// functional control flow nodes containing resource operations.
+//
+// If `resource_ops_to_ignore` is set then nodes for which it returns true are
+// ignored (we pretend these nodes are not resource operations).
+Status ComputeIncompatibleResourceOperationPairs(
+ const Graph& g, const FunctionLibraryDefinition* flib_def,
+ const std::function& resource_ops_to_ignore,
+ std::vector>* result);
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e54b547abcfea698fe79e81dce547ea7858ff829
--- /dev/null
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
@@ -0,0 +1,540 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+Node* MakeRead(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output read =
+ ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
+ return read.node();
+}
+
+Node* MakeWrite(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write =
+ ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
+ ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle,
+ value_to_write);
+ return assign_op.operation.node();
+}
+
+Node* MakeModify(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f);
+ ops::AssignAddVariableOp assign_add_op(scope.WithOpName("Increment" + id),
+ var_handle, value_to_write);
+ return assign_add_op.operation.node();
+}
+
+Node* MakeNeutral(const Scope& scope, const string& id) {
+ return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
+}
+
+Status ComputeIncompatiblePairs(Graph* g,
+ std::vector>* result) {
+ FixupSourceAndSinkEdges(g);
+ return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {},
+ result);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, read);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair write_read_pair = {write->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_read_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(read, write);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 0);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadWriteNoEdges) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ MakeRead(root, "R");
+ MakeWrite(root, "W");
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 0);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadModify) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+
+ root.graph()->AddControlEdge(read, modify);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 1);
+ std::pair read_modify_pair = {read->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_modify_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+
+ root.graph()->AddControlEdge(modify, read);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair modify_read_pair = {modify->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], modify_read_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(modify, write);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 1);
+ std::pair modify_write_pair = {modify->id(), write->id()};
+ EXPECT_EQ(incompatible_pairs[0], modify_write_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteModify) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, modify);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair write_modify_pair = {write->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_modify_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(read, modify);
+ root.graph()->AddControlEdge(modify, write);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 2);
+ std::pair modify_write_pair = {modify->id(), write->id()};
+ std::pair read_modify_pair = {read->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_modify_pair);
+ EXPECT_EQ(incompatible_pairs[1], modify_write_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, modify);
+ root.graph()->AddControlEdge(modify, read);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 3);
+
+ std::pair write_modify_pair = {write->id(), modify->id()};
+ std::pair modify_read_pair = {modify->id(), read->id()};
+ std::pair write_read_pair = {write->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], modify_read_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_read_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_modify_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, read);
+ root.graph()->AddControlEdge(read, modify);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 3);
+
+ std::pair write_modify_pair = {write->id(), modify->id()};
+ std::pair write_read_pair = {write->id(), read->id()};
+ std::pair read_modify_pair = {read->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_modify_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_read_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_modify_pair);
+}
+
+FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
+ FunctionDefLibrary flib_def;
+ FunctionDef func = FunctionDefHelper::Create(
+ /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
+ /*attr_def*/
+ {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
+ /*ret_def=*/{{"out", "out:output:0"}});
+ *flib_def.add_function() = std::move(func);
+ return flib_def;
+}
+
+Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name,
+ Status* status) {
+ NodeDef call_node;
+ call_node.set_name(node_name);
+ call_node.set_op(callee_name);
+ return graph->AddNode(call_node, status);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, CallRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* read = MakeRead(root, "R");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(call, read);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair call_read_edge = {call->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], call_read_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadCall) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* read = MakeRead(root, "R");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(read, call);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair read_call_edge = {read->id(), call->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_call_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, CallWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* write = MakeWrite(root, "W");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(call, write);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair call_write_edge = {call->id(), write->id()};
+ EXPECT_EQ(incompatible_pairs[0], call_write_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteCall) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* write = MakeWrite(root, "W");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(write, call);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair write_call_edge = {write->id(), call->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_call_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, SymbolicGradientRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* read = MakeRead(root, "R");
+ NameAttrList fn;
+ fn.set_name("Const_func");
+ Node* symbolic_gradient =
+ ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)},
+ /*Tout=*/{DT_FLOAT}, fn)
+ .output[0]
+ .node();
+
+ root.graph()->AddControlEdge(symbolic_gradient, read);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair symbolic_gradient_read_edge = {symbolic_gradient->id(),
+ read->id()};
+ EXPECT_EQ(incompatible_pairs[0], symbolic_gradient_read_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteSymbolicGradient) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* write = MakeWrite(root, "W");
+ NameAttrList fn;
+ fn.set_name("Const_func");
+ Node* symbolic_gradient =
+ ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)},
+ /*Tout=*/{DT_FLOAT}, fn)
+ .output[0]
+ .node();
+
+ root.graph()->AddControlEdge(write, symbolic_gradient);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair write_symbolic_gradient_edge = {write->id(),
+ symbolic_gradient->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_symbolic_gradient_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* neutral_0 = MakeNeutral(root, "N0");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral_1 = MakeNeutral(root, "N1");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral_0);
+ root.graph()->AddControlEdge(neutral_0, read_0);
+ root.graph()->AddControlEdge(read_0, write_1);
+ root.graph()->AddControlEdge(write_1, neutral_1);
+ root.graph()->AddControlEdge(neutral_1, read_1);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 5);
+ std::pair write_0_read_0_pair = {write_0->id(), read_0->id()};
+ std::pair write_0_read_1_pair = {write_0->id(), read_1->id()};
+ std::pair write_1_read_1_pair = {write_1->id(), read_1->id()};
+ std::pair write_0_write_1_pair = {write_0->id(), write_1->id()};
+ std::pair read_0_read_1_pair = {read_0->id(), read_1->id()};
+
+ EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_0_write_1_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[3], read_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[4], write_1_read_1_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral = MakeNeutral(root, "N");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral);
+ root.graph()->AddControlEdge(write_1, neutral);
+ root.graph()->AddControlEdge(neutral, read_0);
+ root.graph()->AddControlEdge(neutral, read_1);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 4);
+ std::pair write_0_read_0_pair = {write_0->id(), read_0->id()};
+ std::pair write_0_read_1_pair = {write_0->id(), read_1->id()};
+ std::pair write_1_read_0_pair = {write_1->id(), read_0->id()};
+ std::pair write_1_read_1_pair = {write_1->id(), read_1->id()};
+
+ EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, DagOfOpsWithRepeatedPaths) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral = MakeNeutral(root, "N");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral);
+ root.graph()->AddControlEdge(write_1, neutral);
+ root.graph()->AddControlEdge(neutral, read_0);
+ root.graph()->AddControlEdge(neutral, read_1);
+ root.graph()->AddControlEdge(write_1, read_1);
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 4);
+ std::pair write_0_read_0_pair = {write_0->id(), read_0->id()};
+ std::pair write_0_read_1_pair = {write_0->id(), read_1->id()};
+ std::pair write_1_read_0_pair = {write_1->id(), read_0->id()};
+ std::pair write_1_read_1_pair = {write_1->id(), read_1->id()};
+
+ EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, Loop) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output init_value = ops::Placeholder(root.WithOpName("init"), DT_FLOAT);
+ Output loop_cond = ops::Placeholder(root.WithOpName("init"), DT_BOOL);
+ Output enter_value =
+ ops::internal::Enter(root.WithOpName("enter"), init_value, "fr");
+ ops::Merge iv(root.WithOpName("iv"), {enter_value, enter_value});
+ ops::Switch latch(root.WithOpName("latch"), iv.output, loop_cond);
+ ops::internal::Exit exit(root.WithOpName("exit"), iv.output);
+ Output next_iteration =
+ ops::NextIteration(root.WithOpName("next_iteration"), latch.output_true);
+ TF_ASSERT_OK(
+ root.graph()->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1));
+
+ Node* write = MakeWrite(root, "W");
+ Node* read = MakeRead(root, "R");
+
+ root.graph()->AddControlEdge(iv.output.node(), write);
+ root.graph()->AddControlEdge(write, read);
+ root.graph()->AddControlEdge(read, next_iteration.node());
+
+ std::vector> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+
+ std::pair write_read_pair = {write->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_read_pair);
+}
+
+bool IsResourceArgDef(const OpDef::ArgDef& arg_def) {
+ return arg_def.type() == DT_RESOURCE;
+}
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index 38adacd93bc43ef17734b909af862e063574e986..4f2fabd658330b8ab182e13e02ed0bca41641e46 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include
+#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -207,4 +208,27 @@ bool HasResourceInputOrOutput(const Node& node) {
void RemoveFromXlaCluster(NodeDef* node_def) {
node_def->mutable_attr()->erase(kXlaClusterAttr);
}
+
+Status AdjustCycleDetectionGraphForResourceOps(
+ const Graph* graph, const FunctionLibraryDefinition* flib_def,
+ const std::function& resource_ops_to_ignore,
+ GraphCycles* cycles) {
+ std::vector> unsafe_deps;
+ TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
+ *graph, flib_def, resource_ops_to_ignore, &unsafe_deps));
+
+ // An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are
+ // operations that interact with resource variables, must not be put in the
+ // same cluster. We enforce this constraint by creating a phantom node, X,
+ // and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P
+ // and Q together since that would create a cycle with X.
+
+ for (std::pair unsafe_dep : unsafe_deps) {
+ int phantom_node_id = cycles->NewNode();
+ CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id));
+ CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second));
+ }
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index 662a53d89eb37128be54abc95e1748c6d5f9081f..b0439a63ca6476b6b1d63e65308712270381dd9f 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -55,6 +55,13 @@ void RemoveFromXlaCluster(NodeDef* node_def);
// Returns true if `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node);
+// Adds edges to `cycles` to prevent clustering resource operations that cannot
+// be legally clustered.
+Status AdjustCycleDetectionGraphForResourceOps(
+ const Graph* graph, const FunctionLibraryDefinition* flib_def,
+ const std::function& resource_ops_to_ignore,
+ GraphCycles* cycles);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc
index 2cb351e1ecdb4523a8652886af156540e4736b18..65bbf3efe85ba30f44531ff6d54b041786dca0a5 100644
--- a/tensorflow/compiler/jit/xla_cluster_util_test.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 7140d47a9421ec73d0144e855b490f89569e6ae9..ef6b0e67d3c4007f86dc7eef89cacb4cea98fc15 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -230,7 +230,7 @@ Status XlaCompilationCache::Compile(
const std::map& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options) {
+ const XlaCompiler::CompileOptions& compile_options) {
return CompileImpl(options, function, constant_args, variable_args, ctx,
compilation_result, executable, compile_options, false);
}
@@ -241,7 +241,7 @@ Status XlaCompilationCache::CompileSingleOp(
const std::map& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options) {
+ const XlaCompiler::CompileOptions& compile_options) {
const NodeDef& def = ctx->op_kernel().def();
NameAttrList name;
name.set_name(def.op());
@@ -256,7 +256,7 @@ Status XlaCompilationCache::CompileImpl(
const std::map& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options,
+ const XlaCompiler::CompileOptions& compile_options,
bool compile_single_op) {
CHECK_NE(executable, nullptr);
VLOG(1) << "XlaCompilationCache::Compile " << DebugString();
@@ -324,13 +324,12 @@ Status XlaCompilationCache::CompileImpl(
entry->compiled = true;
if (compile_single_op) {
- entry->compilation_status = compiler.CompileSingleOp(
- compile_options ? *compile_options : XlaCompiler::CompileOptions(),
- signature.name, ctx, args, &entry->compilation_result);
+ entry->compilation_status =
+ compiler.CompileSingleOp(compile_options, signature.name, ctx, args,
+ &entry->compilation_result);
} else {
entry->compilation_status = compiler.CompileFunction(
- compile_options ? *compile_options : XlaCompiler::CompileOptions(),
- function, args, &entry->compilation_result);
+ compile_options, function, args, &entry->compilation_result);
}
TF_RETURN_IF_ERROR(entry->compilation_status);
CHECK_EQ(entry->executable.get(), nullptr);
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index fc5f008f4f52c32d97e680784082d0e7bcb7d8eb..10ad87e38cc4d614e869782329f84351bc3b1f0b 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -70,7 +70,7 @@ class XlaCompilationCache : public ResourceBase {
OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options);
+ const XlaCompiler::CompileOptions& compile_options);
// As above, but calls XlaCompiler::CompileSingleOp instead of
// XlaCompiler::CompileFunction.
@@ -80,7 +80,7 @@ class XlaCompilationCache : public ResourceBase {
const std::map& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options);
+ const XlaCompiler::CompileOptions& compile_options);
xla::LocalClient* client() const { return client_; }
const DeviceType& device_type() const { return device_type_; }
@@ -96,7 +96,7 @@ class XlaCompilationCache : public ResourceBase {
OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options,
+ const XlaCompiler::CompileOptions& compile_options,
bool compile_single_op);
// Takes `result` which has been compiled from a Tensorflow subgraph to a
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index dd84fb34c171f8d2174444ddd3b3b476e7142718..3ba48e8c318f84a4691fb74434bc009fdd0d81bf 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -177,7 +177,7 @@ Status XlaCompileOnDemandOp::Compile(
std::map variable_args = GetVariables(ctx);
return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
- result, executable, &compile_options);
+ result, executable, compile_options);
}
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 70e6d0be0f2cffe98fd77fddac5866789c411a51..50c902fdfc06e9fb2cbcd9dd44640a7d40d0fe81 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -365,11 +365,7 @@ Status XlaDevice::FillContextMap(const Graph* graph,
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
- // When Xprof profiling is off (which is the default), constructing the
- // activity is simple enough that its overhead is negligible.
- tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
- op_kernel->IsExpensive());
- op_kernel->Compute(context);
+ TracingDevice::Compute(op_kernel, context);
}
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 2027ec7737895557a46df38fe6d0ddb372e3bb67..ee07c5c9643ef1119b9077326c1cf7c83930e90c 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -184,18 +184,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
- if (status.ok()) {
- xla_tensor->set_host_tensor(*cpu_tensor);
- host_to_device_stream_->ThenDoHostCallback([this, done]() {
- // We must not call the done closure directly from DoHostCallback
- // to avoid a deadlock. If done() is the callback that ends an
- // Executor's run, the Executor may call XlaDevice::Sync() inside the
- // callback. This deadlocks, because XlaDevice::Sync() waits for all
- // stream activity to complete.
- thread_pool_->Schedule([done]() { done(Status::OK()); });
- });
- return;
- }
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
@@ -208,8 +196,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
host_to_device_stream_.get(), block_status.error_message().c_str());
}
}
- xla_tensor->set_host_tensor(*cpu_tensor);
-
+ if (status.ok()) {
+ xla_tensor->set_host_tensor(*cpu_tensor);
+ }
done(status);
}
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
index 4b499b161371ecece14447b29fbf809b6e8857db..07cfab615157650aea0e15cdafa8c9b0925f9e5f 100644
--- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
@@ -41,8 +41,8 @@ static bool IsShapeConsumerOp(const Node& node) {
}
// Returns true if the op can be decomposed into XLA ops for which
-// there are fusable elemental implementations.
-bool IsXlaFusable(const NodeDef& node) {
+// there are fusible elemental implementations.
+static bool IsXlaFusible(const NodeDef& node) {
static const std::unordered_set* elementwise_ops =
new std::unordered_set(
{// tf2xla/kernels/aggregate_ops.cc
@@ -176,9 +176,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type));
if (device_type.type_string().find("XLA") != string::npos) continue;
- // Assume all fusable ops are registered.
+ // Assume all fusible ops are registered.
// TODO(hpucha): Check for registration if possible.
- if (!IsXlaFusable(node->def())) {
+ if (!IsXlaFusible(node->def())) {
continue;
}
@@ -208,6 +208,8 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles));
+ TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
+ &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));
// TODO(hpucha): Make clustering more robust. There are two known issues that
// we need to mitigate: (a) Non-resource variables can cause deadlocks
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
index 5736760a878dc857a8558093054d0adc0f727398..68e19c8a135735a79fcabf121e619157fa22b4d8 100644
--- a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/graph_def_builder.h"
@@ -71,7 +73,7 @@ TEST_F(XlaFusionOptimizerTest, Chains) {
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
-TEST_F(XlaFusionOptimizerTest, FusableOps) {
+TEST_F(XlaFusionOptimizerTest, FusibleOps) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
@@ -179,5 +181,28 @@ TEST_F(XlaFusionOptimizerTest, CompilableCycles) {
EXPECT_EQ(clusters["A"], clusters["C"]);
}
+TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output var_handle =
+ ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({}));
+ Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f);
+ Output begin = ops::Const(root.WithOpName("begin"), 0);
+ Output end = ops::Const(root.WithOpName("end"), 1);
+ Output strides = ops::Const(root.WithOpName("strides"), 1);
+ ops::ResourceStridedSliceAssign assign_1(
+ root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign);
+ ops::ResourceStridedSliceAssign assign_2(
+ root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign);
+ root.graph()->AddControlEdge(assign_1.operation.node(),
+ assign_2.operation.node());
+ grappler::GrapplerItem item;
+ root.graph()->ToGraphDef(&item.graph);
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_NE(clusters["assign_1"], clusters["assign_2"]);
+}
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 2ffce9298d99e1e136e15e9a4b0e3f5b26121bd5..affeab4a8c43b63ac0e2b8ef40de5223ce39d410 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -271,31 +271,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
}
} else {
const TensorShape& shape = kernel->outputs[i].shape;
- VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
-
- se::DeviceMemoryBase buffer = output.buffer({output_num});
- if (allocate_xla_tensors_) {
- Tensor* output_tensor;
- TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
- XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
- if (xla_tensor) {
- xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
- ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
- if (use_multiple_streams_) {
- xla_tensor->SetDefinedOn(stream, definition_event);
+ const DataType& type = kernel->outputs[i].type;
+ VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
+ << DataTypeString(type);
+ if (type == DT_RESOURCE) {
+ ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
+ } else {
+ se::DeviceMemoryBase buffer = output.buffer({output_num});
+ if (allocate_xla_tensors_) {
+ Tensor* output_tensor;
+ TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
+ if (xla_tensor) {
+ xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
+ ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+ if (use_multiple_streams_) {
+ xla_tensor->SetDefinedOn(stream, definition_event);
+ }
+ } else {
+ // xla_tensor wasn't valid, which must mean this is a zero-element
+ // tensor.
+ CHECK_EQ(output_tensor->TotalBytes(), 0);
}
} else {
- // xla_tensor wasn't valid, which must mean this is a zero-element
- // tensor.
- CHECK_EQ(output_tensor->TotalBytes(), 0);
+ Tensor output_tensor = XlaTensorBuffer::MakeTensor(
+ ctx->expected_output_dtype(i), shape, buffer, allocator);
+ output.set_buffer(xla::OwningDeviceMemory(), {output_num});
+ ctx->set_output(i, output_tensor);
}
- } else {
- Tensor output_tensor = XlaTensorBuffer::MakeTensor(
- ctx->expected_output_dtype(i), shape, buffer, allocator);
- output.set_buffer(xla::OwningDeviceMemory(), {output_num});
- ctx->set_output(i, output_tensor);
+ ++output_num;
}
- ++output_num;
}
if (VLOG_IS_ON(3)) {
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 47311d2630175ee8c7eb52d587f138128bcad3df..cf02926e0675e94381462f9579c36909c3bf7de9 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -72,7 +72,7 @@ py_test(
tf_xla_py_test(
name = "adadelta_test",
- size = "medium",
+ size = "large",
srcs = ["adadelta_test.py"],
deps = [
":xla_test",
@@ -728,6 +728,7 @@ tf_xla_py_test(
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -1190,3 +1191,19 @@ tf_xla_py_test(
"//tensorflow/python:platform_test",
],
)
+
+tf_xla_py_test(
+ name = "xla_ops_test",
+ size = "small",
+ srcs = ["xla_ops_test.py"],
+ disabled_backends = ["cpu_ondemand"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/compiler/tf2xla/python:xla",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py
index 3e3c09c66e72c4de141b64cea3c4693fabb7b2a2..b7b7fda293b69d6f0cec61d0d234277636a3670d 100644
--- a/tensorflow/compiler/tests/adadelta_test.py
+++ b/tensorflow/compiler/tests/adadelta_test.py
@@ -33,7 +33,7 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
num_updates = 4 # number of ADADELTA steps to perform
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
for grad in [0.2, 0.1, 0.01]:
for lr in [1.0, 0.5, 0.1]:
var0_init = [1.0, 2.0]
diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py
index dc1625793aa44b96d3b96e175237caf96e7d7e74..69fb3ec2964a09508e612515b9e291fc14121d68 100644
--- a/tensorflow/compiler/tests/adagrad_da_test.py
+++ b/tensorflow/compiler/tests/adagrad_da_test.py
@@ -33,7 +33,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
def testAdagradDAWithoutRegularizationBasic1(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
global_step = resource_variable_ops.ResourceVariable(
0, dtype=dtypes.int64)
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
@@ -69,7 +69,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
def testAdagradDAwithoutRegularizationBasic2(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
global_step = resource_variable_ops.ResourceVariable(
0, dtype=dtypes.int64)
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
@@ -100,7 +100,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
def testAdagradDAWithL1(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
global_step = resource_variable_ops.ResourceVariable(
0, dtype=dtypes.int64)
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
@@ -131,7 +131,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
def testAdagradDAWithL1_L2(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
global_step = resource_variable_ops.ResourceVariable(
0, dtype=dtypes.int64)
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py
index d775850a80e9f83f7b2c9f1cf8997dd50e229635..ab69319c59fb07e7ce56c3c287a50a6290effdfd 100644
--- a/tensorflow/compiler/tests/adagrad_test.py
+++ b/tensorflow/compiler/tests/adagrad_test.py
@@ -32,7 +32,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -57,7 +57,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRate(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -83,7 +83,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
def testSharing(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py
index c4fdbc5974319db9243eb2c323746cbaaea795f6..3ed1d41b7121f44dd7470f61180f7a7055369174 100644
--- a/tensorflow/compiler/tests/adamax_test.py
+++ b/tensorflow/compiler/tests/adamax_test.py
@@ -49,7 +49,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for i, dtype in enumerate(self.float_types):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@@ -100,7 +100,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRate(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py
index 9ec5a964cbb4dd98d2ef2d0b684872292118800f..1bc07ace23ccdc83103abe71ee11b72994c75a6d 100644
--- a/tensorflow/compiler/tests/addsign_test.py
+++ b/tensorflow/compiler/tests/addsign_test.py
@@ -63,7 +63,7 @@ class AddSignTest(xla_test.XLATestCase):
alpha=1.0,
beta=0.9):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
# Initialize variables for numpy implementation.
m0, m1 = 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype)
diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py
index 9d3a889b1f54c813e881bb03b5275f809af1b3c8..4155342787fbbdeaf5c5958c44d007b1ea0660ed 100644
--- a/tensorflow/compiler/tests/argminmax_test.py
+++ b/tensorflow/compiler/tests/argminmax_test.py
@@ -40,7 +40,7 @@ class ArgMinMaxTest(xla_test.XLATestCase):
op_input: numpy input array to use as input to 'op'.
expected: numpy array representing the expected output of 'op'.
"""
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
pinp = array_ops.placeholder(
dtypes.as_dtype(op_input.dtype), op_input.shape, name="a")
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 5b7001b5a463ae0bd4e8f07032256717aab70d49..17280e445b329d1541aaed78ec106f8f282cbc74 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -36,7 +36,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
"""Test cases for binary operators."""
def _testBinary(self, op, a, b, expected, equality_test=None):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
@@ -1010,7 +1010,38 @@ class BinaryOpsTest(xla_test.XLATestCase):
[7, 7, 7, 7, 7, 7]],
dtype=dtype))
- def testMirrorPad(self):
+ def testSymmetricMirrorPad(self):
+ mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC")
+ for dtype in self.numeric_types:
+ self._testBinary(
+ mirror_pad,
+ np.array(
+ [
+ [1, 2, 3], #
+ [4, 5, 6], #
+ ],
+ dtype=dtype),
+ np.array([[
+ 2,
+ 2,
+ ], [3, 3]], dtype=np.int32),
+ expected=np.array(
+ [
+ [6, 5, 4, 4, 5, 6, 6, 5, 4], #
+ [3, 2, 1, 1, 2, 3, 3, 2, 1], #
+ [3, 2, 1, 1, 2, 3, 3, 2, 1], #
+ [6, 5, 4, 4, 5, 6, 6, 5, 4], #
+ [6, 5, 4, 4, 5, 6, 6, 5, 4], #
+ [3, 2, 1, 1, 2, 3, 3, 2, 1], #
+ ],
+ dtype=dtype))
+ self._testBinary(
+ mirror_pad,
+ np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
+ np.array([[0, 0], [0, 0]], dtype=np.int32),
+ expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
+
+ def testReflectMirrorPad(self):
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
for dtype in self.numeric_types:
self._testBinary(
@@ -1372,5 +1403,40 @@ class BinaryOpsTest(xla_test.XLATestCase):
[[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
dtype=dtype))
+ def testBroadcastTo(self):
+ for dtype in self.all_types:
+ x = np.random.randint(0, high=100, size=[2, 3])
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([2, 3], dtype=np.int32),
+ expected=x)
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([6, 6], dtype=np.int32),
+ expected=np.tile(x, [3, 2]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 4, 3], dtype=np.int32),
+ expected=np.tile(x, [7, 2, 1]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 0, 3], dtype=np.int32),
+ expected=np.zeros([7, 0, 3], dtype=dtype))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 1, 2, 9], dtype=np.int32),
+ expected=np.tile(x, [7, 1, 1, 3]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ np.zeros([2, 0], dtype=dtype),
+ np.array([4, 0], dtype=np.int32),
+ expected=np.zeros([4, 0], dtype=dtype))
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py
index ef4d5f6322b7ae79b051795b5af7e6f7f1e55550..5c24db539bce5df701d8229290ddb4c20997d40a 100644
--- a/tensorflow/compiler/tests/bucketize_op_test.py
+++ b/tensorflow/compiler/tests/bucketize_op_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class BucketizationOpTest(xla_test.XLATestCase):
def testInt(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.int32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11])
@@ -38,7 +38,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]}))
def testFloat(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.float32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.])
@@ -48,7 +48,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]}))
def test2DInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.float32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11])
@@ -58,7 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
{p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]}))
def testInvalidBoundariesOrder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.int32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11])
@@ -67,7 +67,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
sess.run(op, {p: [-5, 0]})
def testBoundariesNotList(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Expected list.*"):
p = array_ops.placeholder(dtypes.int32)
with self.test_scope():
diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py
index a4e7f75081dfd07fd4b5c94c33908aab8e7d8aa9..a57d1dc81ea2c9c188b0a3005904738aa8156bf3 100644
--- a/tensorflow/compiler/tests/categorical_op_test.py
+++ b/tensorflow/compiler/tests/categorical_op_test.py
@@ -56,7 +56,7 @@ class CategoricalTest(xla_test.XLATestCase):
Returns:
Frequencies from sampled classes; shape [batch_size, num_classes].
"""
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
random_seed.set_random_seed(1618)
op = random_ops.multinomial(logits, num_samples,
output_dtype=dtypes.int32)
@@ -79,7 +79,7 @@ class CategoricalTest(xla_test.XLATestCase):
def _testRngIsNotConstant(self, rng, dtype, output_dtype):
# Tests that 'rng' does not always return the same value.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = rng(dtype, output_dtype)
@@ -107,7 +107,7 @@ class CategoricalTest(xla_test.XLATestCase):
def testCategoricalIsInRange(self):
for dtype in self.float_types:
for output_dtype in self.output_dtypes():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = random_ops.multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py
index ed532db0ee5553a275192e6cc3ebf394075fa0e1..d1896a50f7037f2972cba8a4fa16cc1e2cd4fe3e 100644
--- a/tensorflow/compiler/tests/cholesky_op_test.py
+++ b/tensorflow/compiler/tests/cholesky_op_test.py
@@ -54,7 +54,7 @@ class CholeskyOpTest(xla_test.XLATestCase):
def _verifyCholesky(self, x, atol=1e-6):
# Verify that LL^T == x.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(
dtypes.as_dtype(x.dtype), shape=x.shape)
with self.test_scope():
diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py
index e42ebf8f9e01dab13cde15979ffc42b7c0fbc57b..88bd58b2da6b2892f898ad10f3467d8ce39d6388 100644
--- a/tensorflow/compiler/tests/clustering_test.py
+++ b/tensorflow/compiler/tests/clustering_test.py
@@ -38,7 +38,7 @@ class ClusteringTest(xla_test.XLATestCase):
val1 = np.array([4, 3, 2, 1], dtype=np.float32)
val2 = np.array([5, 6, 7, 8], dtype=np.float32)
expected = val1 + val2
- with self.test_session():
+ with self.cached_session():
with self.test_scope():
input1 = constant_op.constant(val1, name="const1")
input2 = constant_op.constant(val2, name="const2")
@@ -50,7 +50,7 @@ class ClusteringTest(xla_test.XLATestCase):
val1 = np.array([4, 3, 2, 1]).astype(np.float32)
val2 = np.array([5, 6, 7, 8]).astype(np.float32)
expected = val1 + val2
- with self.test_session():
+ with self.cached_session():
with ops.device(CPU_DEVICE):
input1 = constant_op.constant(val1, name="const1")
input2 = constant_op.constant(val2, name="const2")
@@ -68,7 +68,7 @@ class ClusteringTest(xla_test.XLATestCase):
# where x and z are placed on the CPU and y and w are placed on the XLA
# device. If y and w are clustered for compilation, then the graph will
# deadlock since the clustered graph will contain a self-loop.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device(CPU_DEVICE):
x = array_ops.placeholder(dtypes.float32, [2])
with self.test_scope():
@@ -81,7 +81,7 @@ class ClusteringTest(xla_test.XLATestCase):
self.assertAllClose(result, [12., 2.], rtol=1e-3)
def testHostMemory(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.int32)
with self.test_scope():
y = x + 1
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index d9ad4281477e87f79f2ecb52989ae86a5030d0cc..37e5318bb54c5d8ecdedc7bb346e89765f2adf35 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest
class ConcatTest(xla_test.XLATestCase):
def testHStack(self):
- with self.test_session():
+ with self.cached_session():
p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
with self.test_scope():
@@ -49,7 +49,7 @@ class ConcatTest(xla_test.XLATestCase):
self.assertAllEqual(result[4:, :], params[p2])
def testVStack(self):
- with self.test_session():
+ with self.cached_session():
p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
with self.test_scope():
@@ -65,7 +65,7 @@ class ConcatTest(xla_test.XLATestCase):
self.assertAllEqual(result[:, 4:], params[p2])
def testInt32(self):
- with self.test_session():
+ with self.cached_session():
p1 = np.random.rand(2, 3).astype("i")
p2 = np.random.rand(2, 3).astype("i")
x1 = constant_op.constant(p1)
@@ -88,7 +88,7 @@ class ConcatTest(xla_test.XLATestCase):
dtype_feed = dtypes.float32
else:
dtype_feed = dtype
- with self.test_session():
+ with self.cached_session():
p = []
for i in np.arange(num_tensors):
input_shape = shape
@@ -130,7 +130,7 @@ class ConcatTest(xla_test.XLATestCase):
self._testRandom(dtypes.int32)
def _testGradientsSimple(self):
- with self.test_session():
+ with self.cached_session():
inp = []
inp_tensors = []
with self.test_scope():
@@ -157,7 +157,7 @@ class ConcatTest(xla_test.XLATestCase):
self._testGradientsSimple()
def _testGradientsFirstDim(self):
- with self.test_session():
+ with self.cached_session():
inp = []
inp_tensors = []
with self.test_scope():
@@ -185,7 +185,7 @@ class ConcatTest(xla_test.XLATestCase):
self._testGradientsFirstDim()
def _testGradientsLastDim(self):
- with self.test_session():
+ with self.cached_session():
inp = []
inp_tensors = []
with self.test_scope():
@@ -220,7 +220,7 @@ class ConcatTest(xla_test.XLATestCase):
# Random dim to concat on
concat_dim = np.random.randint(5)
concat_dim_sizes = np.random.randint(1, 5, size=num_tensors)
- with self.test_session():
+ with self.cached_session():
inp = []
inp_tensors = []
with self.test_scope():
@@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase):
def DISABLED_testZeroSize(self):
# Verify that concat doesn't crash and burn for zero size inputs
np.random.seed(7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
for shape0 in (), (2,):
axis = len(shape0)
@@ -276,14 +276,14 @@ class ConcatTest(xla_test.XLATestCase):
def testConcatTuple(self):
c1 = np.random.rand(4, 4).astype(np.float32)
c2 = np.random.rand(4, 4).astype(np.float32)
- with self.test_session():
+ with self.cached_session():
with self.test_scope():
concat_list_t = array_ops.concat([c1, c2], 0)
concat_tuple_t = array_ops.concat((c1, c2), 0)
self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
def testConcatNoScalars(self):
- with self.test_session():
+ with self.cached_session():
with self.test_scope():
scalar = constant_op.constant(7)
dim = array_ops.placeholder(dtypes.int32)
@@ -295,7 +295,7 @@ class ConcatTest(xla_test.XLATestCase):
class ConcatOffsetTest(xla_test.XLATestCase):
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
@@ -309,7 +309,7 @@ class ConcatOffsetTest(xla_test.XLATestCase):
class PackTest(xla_test.XLATestCase):
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
@@ -319,7 +319,7 @@ class PackTest(xla_test.XLATestCase):
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
def testScalars(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
s0 = constant_op.constant(2, dtypes.int32)
s1 = constant_op.constant(3, dtypes.int32)
@@ -329,7 +329,7 @@ class PackTest(xla_test.XLATestCase):
self.assertAllEqual(ans, [2, 3, 5])
def testEmpty(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
s0 = constant_op.constant([[]], dtypes.int32)
s1 = constant_op.constant([[]], dtypes.int32)
diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py
index f9db103f6d0f9ea0e393a0971593552ec5c14079..af00ff287d43a8542b5a3d14eedc00c3d7aef1b7 100644
--- a/tensorflow/compiler/tests/conv2d_test.py
+++ b/tensorflow/compiler/tests/conv2d_test.py
@@ -87,7 +87,7 @@ class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase):
dilations = test_utils.PermuteDimsBetweenDataFormats(
dilations, data_format_src, data_format_dst)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
with self.test_scope():
@@ -288,7 +288,7 @@ class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase):
dilations = test_utils.PermuteDimsBetweenDataFormats(
dilations, data_format_src, data_format_dst)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
with self.test_scope():
@@ -586,7 +586,7 @@ class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase):
dilations = test_utils.PermuteDimsBetweenDataFormats(
dilations, data_format_src, data_format_dst)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
with self.test_scope():
diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py
index 31ee41f04f27d387415e9fa2c4fa70b33cab7b04..33fd983b5485e503c2fcc96db2dfdecfc41e309f 100644
--- a/tensorflow/compiler/tests/conv3d_test.py
+++ b/tensorflow/compiler/tests/conv3d_test.py
@@ -36,7 +36,7 @@ from tensorflow.python.platform import googletest
class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase):
def testGradient(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
for padding in ["SAME", "VALID"]:
for stride in [1, 2]:
np.random.seed(1)
@@ -69,7 +69,7 @@ class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase):
class Conv3DTransposeTest(xla_test.XLATestCase):
def testConv3DTransposeSingleStride(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
strides = [1, 1, 1, 1, 1]
# Input, output: [batch, depth, height, width, channel]
@@ -119,7 +119,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
self.assertAllClose(target, value[n, d, h, w, k])
def testConv3DTransposeSame(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
strides = [1, 2, 2, 2, 1]
# Input, output: [batch, depth, height, width, depth]
@@ -157,7 +157,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
self.assertAllClose(target, value[n, d, h, w, k])
def testConv3DTransposeValid(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
strides = [1, 2, 2, 2, 1]
# Input, output: [batch, depth, height, width, depth]
@@ -217,7 +217,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
np.random.seed(1) # Make it reproducible.
x_val = np.random.random_sample(x_shape).astype(np.float64)
f_val = np.random.random_sample(f_shape).astype(np.float64)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose(
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index 865f60ccab46ec6829e49409508303052944e13b..04f3b3ef4905984b0432a536c3b1c275738ede17 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -86,7 +86,7 @@ class DenseLayerTest(test.TestCase):
XlaLaunch op by XLA.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32)
with jit_scope():
y = layers.dense(x, 3)
@@ -113,7 +113,7 @@ class DenseLayerTest(test.TestCase):
cluster, causing dense layer to be split into TWO XlaLaunch ops.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
with jit_scope():
y = layers.dense(x, 3)
diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py
index 98dc73e189f99b7b811487756659d89dacb97d8a..6ef8a68ca5d35d3d2f78f0cb491e7bb98ff97ac9 100644
--- a/tensorflow/compiler/tests/depthwise_conv_op_test.py
+++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py
@@ -151,7 +151,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
dtype=data_type).reshape(tensor_in_sizes)
x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
dtype=data_type).reshape(filter_in_sizes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if data_type == np.float32:
tolerance = 1e-4
else:
@@ -247,7 +247,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
dtype=np.float32).reshape(tensor_in_sizes)
x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
dtype=np.float32).reshape(filter_in_sizes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32)
t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=np.float32)
with self.test_scope():
@@ -321,7 +321,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(use_xla):
- with self.test_session():
+ with self.cached_session():
t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
t1 = array_ops.placeholder(np.float32, shape=filter_sizes)
t2 = array_ops.placeholder(np.float32, shape=output_sizes)
@@ -356,7 +356,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(use_xla):
- with self.test_session():
+ with self.cached_session():
t0 = array_ops.placeholder(np.float32, shape=input_sizes)
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
t2 = array_ops.placeholder(np.float32, shape=output_sizes)
diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py
index 154e36b10e6da409606ae6022aaf53e34c8e37cc..5f01e128f0b0fa725d99b00ba3406bd50a1b8962 100644
--- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py
+++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
class DynamicUpdateSliceOpsTest(xla_test.XLATestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py
index edd78153b56bb5bf1c268936fb82a60581389733..50b04daa6b9f4159a3c4bdeecaf900a5b35a833c 100644
--- a/tensorflow/compiler/tests/dynamic_stitch_test.py
+++ b/tensorflow/compiler/tests/dynamic_stitch_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.platform import googletest
class DynamicStitchTest(xla_test.XLATestCase):
def _AssertDynamicStitchResultIs(self, indices, data, expected):
- with self.test_session() as session:
+ with self.cached_session() as session:
index_placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices
]
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 3d21fb5864c22a6f449c54d03abc0f234e28dab1..63cee550fde9d9d4314b1541fba191df776a4da2 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -101,7 +101,7 @@ class EagerTest(xla_test.XLATestCase):
self.assertAllEqual(15, product)
# Run some ops graphly
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
with self.test_scope():
three = constant_op.constant(3)
five = constant_op.constant(5)
@@ -351,6 +351,38 @@ class EagerFunctionTest(xla_test.XLATestCase):
var = f(v)
self.assertEqual(2.0, var.numpy())
+ def testReturnResourceHandle(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])
+
+ def f(v):
+ return v.handle
+
+ f = function.defun(f)
+ handle = f(v)
+ self.assertAllEqual(v.numpy(),
+ resource_variable_ops.read_variable_op(
+ handle, dtypes.float32).numpy())
+
+ def testReturnMultipleResourceHandles(self):
+ with self.test_scope():
+ v1 = resource_variable_ops.ResourceVariable(1.25)
+ v2 = resource_variable_ops.ResourceVariable(2.0)
+
+ def f(v):
+ return v.handle, 3.0 * v, v2.handle, v + v2
+
+ f = function.defun(f)
+ v1_handle, v1_times_3, v2_handle, variable_sum = f(v1)
+ self.assertAllEqual(v1.numpy(),
+ resource_variable_ops.read_variable_op(
+ v1_handle, dtypes.float32).numpy())
+ self.assertEqual(3.75, v1_times_3.numpy())
+ self.assertAllEqual(v2.numpy(),
+ resource_variable_ops.read_variable_op(
+ v2_handle, dtypes.float32).numpy())
+ self.assertEqual(3.25, variable_sum.numpy())
+
def testAllArgumentKinds(self):
"""Test a complex function that takes different argument kinds.
@@ -457,6 +489,72 @@ class EagerFunctionTest(xla_test.XLATestCase):
y = two_x_plus_1(x)
self.assertAllEqual([5, 7, 9], y.numpy())
+ def testNestedDefunWithVariable(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(5.0)
+
+ @function.defun
+ def g(x):
+ x = v0 * x
+ return x
+
+ @function.defun
+ def f(x):
+ x = g(v0 * x)
+ return x
+
+ x = constant_op.constant(3.0)
+ y = f(x)
+
+ self.assertEqual(75, y.numpy())
+
+ def testNestedDefunInGradientTape(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(5.0)
+
+ @function.defun
+ def g(x):
+ x = v0 * x
+ return x
+
+ @function.defun
+ def f(x):
+ x = g(v0 * x)
+ return x
+
+ x = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ y = f(x)
+ dy = tape.gradient(y, v0)
+
+ self.assertEqual(75, y.numpy())
+ self.assertEqual(30, dy.numpy())
+
+ def testNestedDefunInGradientTapeDifferentVars(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(5.0)
+ v1 = resource_variable_ops.ResourceVariable(3.0)
+
+ @function.defun
+ def g(x):
+ x = v1 * x
+ return x
+
+ @function.defun
+ def f(x):
+ x = g(v0 * x)
+ return x
+
+ x = constant_op.constant(3.0)
+ with backprop.GradientTape(persistent=True) as tape:
+ y = f(x)
+ dy_v0 = tape.gradient(y, v0)
+ dy_v1 = tape.gradient(y, v1)
+
+ self.assertEqual(45, y.numpy())
+ self.assertEqual(9, dy_v0.numpy())
+ self.assertEqual(15, dy_v1.numpy())
+
class ExcessivePaddingTest(xla_test.XLATestCase):
"""Test that eager execution works with TPU flattened tensors.
diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py
index 5529fdbb090315e1d7f47589777d8a538c90db2b..37061e91d161db352b388a965eb72c9c32d3d752 100644
--- a/tensorflow/compiler/tests/extract_image_patches_op_test.py
+++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py
@@ -44,7 +44,7 @@ class ExtractImagePatches(xla_test.XLATestCase):
strides = [1] + strides + [1]
rates = [1] + rates + [1]
- with self.test_session():
+ with self.cached_session():
image_placeholder = array_ops.placeholder(dtypes.float32)
with self.test_scope():
out_tensor = array_ops.extract_image_patches(
diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py
index c48ab178bf53558084fb500b2811c6f0b77a7943..2178c4455609550226c89ceb185837768be1f622 100644
--- a/tensorflow/compiler/tests/fake_quant_ops_test.py
+++ b/tensorflow/compiler/tests/fake_quant_ops_test.py
@@ -107,7 +107,7 @@ class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase):
],
dtype=np.float32)
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
input_placeholder = array_ops.placeholder(
dtypes.float32, inputs.shape, name="inputs")
@@ -198,7 +198,7 @@ class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase):
[0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0],
dtype=np.float32)
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
gradient_placeholder = array_ops.placeholder(
dtypes.float32, gradients.shape, name="gradients")
@@ -306,7 +306,7 @@ class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase):
],
dtype=np.float32)
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
input_placeholder = array_ops.placeholder(
dtypes.float32, inputs.shape, name="inputs")
@@ -406,7 +406,7 @@ class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase):
expected_backprops_wrt_min = 1.0 + 2.0
expected_backprops_wrt_max = 10.0 + 11.0
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
gradient_placeholder = array_ops.placeholder(
dtypes.float32, gradients.shape, name="gradients")
diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py
index c64ea249ecb97991952a960a6d16e1bb3be35b17..b3e13fbaa6b33bdaa1be123be558059e96de282e 100644
--- a/tensorflow/compiler/tests/fft_test.py
+++ b/tensorflow/compiler/tests/fft_test.py
@@ -71,7 +71,7 @@ class FFTTest(xla_test.XLATestCase):
data = np.reshape(data.astype(np.float32).view(np.complex64), shape)
data = to_32bit(complex_to_input(data))
expected = to_32bit(input_to_expected(data))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
ph = array_ops.placeholder(
dtypes.as_dtype(data.dtype), shape=data.shape)
@@ -93,7 +93,7 @@ class FFTTest(xla_test.XLATestCase):
data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2]
expected = np.swapaxes(expected, -1, -2)
expected *= window.sum() # scipy divides by window sum
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
ph = array_ops.placeholder(
dtypes.as_dtype(data.dtype), shape=data.shape)
diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py
index 0f64cc87cde77fbbef6c4e570879e992bc34bafa..8c7edfd277c992c35a81dd5f261256a86352254e 100644
--- a/tensorflow/compiler/tests/fifo_queue_test.py
+++ b/tensorflow/compiler/tests/fifo_queue_test.py
@@ -31,13 +31,13 @@ from tensorflow.python.platform import test
class FIFOQueueTest(xla_test.XLATestCase):
def testEnqueue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
enqueue_op.run()
def testEnqueueWithShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
enqueue_correct_op.run()
@@ -46,7 +46,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertEqual(1, q.size().eval())
def testMultipleDequeues(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
self.evaluate(q.enqueue([1]))
self.evaluate(q.enqueue([2]))
@@ -55,7 +55,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
def testQueuesDontShare(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
self.evaluate(q.enqueue(1))
q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
@@ -64,13 +64,13 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertAllEqual(self.evaluate(q.dequeue()), 1)
def testEnqueueDictWithoutNames(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
with self.assertRaisesRegexp(ValueError, "must have names"):
q.enqueue({"a": 12.0})
def testParallelEnqueue(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -95,7 +95,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertItemsEqual(elems, results)
def testParallelDequeue(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -119,7 +119,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertItemsEqual(elems, results)
def testDequeue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -133,7 +133,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertEqual([elems[i]], vals)
def testEnqueueAndBlockingDequeue(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -163,7 +163,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertEqual([elem], result)
def testMultiEnqueueAndDequeue(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
@@ -179,12 +179,12 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertEqual([y], y_val)
def testQueueSizeEmpty(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
self.assertEqual([0], q.size().eval())
def testQueueSizeAfterEnqueueAndDequeue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue()
diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py
index 1da97fd51217a0f28d4b3ba2ccfae3f6b094e65b..7ca50b02d9bf3203cbd460c8de13a16defd974a3 100644
--- a/tensorflow/compiler/tests/ftrl_test.py
+++ b/tensorflow/compiler/tests/ftrl_test.py
@@ -112,7 +112,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
def testFtrlwithoutRegularization(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -146,7 +146,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
def testFtrlwithoutRegularization2(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -174,7 +174,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
def testFtrlWithL1(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -202,7 +202,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
def testFtrlWithL1_L2(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -236,7 +236,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
weights will tend to have smaller magnitudes with this parameter set.
"""
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -273,9 +273,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
def testEquivAdagradwithoutRegularization(self):
steps = 5
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val0, val1 = self.equivAdagradTest_FtrlPart(steps, dtype)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype)
self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4, half_rtol=1e-2)
@@ -284,9 +284,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
def testEquivGradientDescentwithoutRegularization(self):
steps = 5
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val0, val1 = self.equivGradientDescentTest_FtrlPart(steps, dtype)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val2, val3 = self.equivGradientDescentTest_GradientDescentPart(
steps, dtype)
diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py
index 04fba444460e714ce96205361ac02ed492206b04..b1891b918c6584abce9da382088ed0037f5319fb 100644
--- a/tensorflow/compiler/tests/function_test.py
+++ b/tensorflow/compiler/tests/function_test.py
@@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = Func(aval, bval)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@@ -105,7 +105,7 @@ class FunctionTest(xla_test.XLATestCase):
def testCompileTimeConstantsInDefun(self):
"""Tests that XLA handles compile-time constants in defuns."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.float32, dtypes.int32, dtypes.int32)
def Foo(a, c, d):
@@ -140,7 +140,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
expected = aval + bval * 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
a = array_ops.placeholder(dtypes.float32, name="a")
b = array_ops.placeholder(dtypes.float32, name="b")
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index 132e42ac7a28d0769b0de12ea0cee6eae752b245..8c018cccb83a05babb0b7f73b80b4f9de7267c98 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -83,7 +83,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
y_ref, mean_ref, var_ref = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format_src)
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
# To avoid constant folding
x_val_converted = test_utils.ConvertBetweenDataFormats(
x_val, data_format_src, data_format)
@@ -126,7 +126,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
y_ref, mean_ref, var_ref = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format_src)
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
# To avoid constant folding
x_val_converted = test_utils.ConvertBetweenDataFormats(
x_val, data_format_src, data_format)
@@ -210,7 +210,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src)
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
grad_val_converted = test_utils.ConvertBetweenDataFormats(
grad_val, data_format_src, data_format)
x_val_converted = test_utils.ConvertBetweenDataFormats(
@@ -260,7 +260,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
var_val = np.random.random_sample(scale_shape).astype(np.float32)
data_format_src = "NHWC"
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
grad_val_converted = test_utils.ConvertBetweenDataFormats(
grad_val, data_format_src, data_format)
x_val_converted = test_utils.ConvertBetweenDataFormats(
diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py
index 23b0aed34fb460f50c241e5a920cb4f6f613b947..7161f4ab339b6f4069dd2b02ddbc6a89973e0074 100644
--- a/tensorflow/compiler/tests/gather_nd_op_test.py
+++ b/tensorflow/compiler/tests/gather_nd_op_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class GatherNdTest(xla_test.XLATestCase):
def _runGather(self, params, indices):
- with self.test_session():
+ with self.cached_session():
paramsp = array_ops.placeholder(params.dtype)
indicesp = array_ops.placeholder(indices.dtype)
with self.test_scope():
@@ -46,7 +46,7 @@ class GatherNdTest(xla_test.XLATestCase):
np.array([[4], [4], [0]], np.int32)))
def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
- with self.test_session():
+ with self.cached_session():
params = np.ones((3, 3), dtype=np.float32)
indices_empty = np.empty((0, 2), dtype=np.int32)
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index e9c8ef7c91a728b7dfc948fd9b315e6c9102f6a3..089d95daab7e502b4ba13796fadc2ba3f209759b 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -42,7 +42,7 @@ class GatherTest(xla_test.XLATestCase):
return data
def testScalar1D(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
data = np.array([0, 1, 2, 3, 7, 5])
for dtype in self.all_tf_types:
for indices in 4, [4], [1, 2, 2, 4, 5]:
@@ -55,7 +55,7 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(np_val, gather_val)
def testScalar2D(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
for dtype in self.all_tf_types:
@@ -69,7 +69,7 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
for dtype in self.all_tf_types:
@@ -87,7 +87,7 @@ class GatherTest(xla_test.XLATestCase):
if np.int64 not in self.int_types:
return
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
# The indices must be in bounds for any axis.
@@ -114,7 +114,7 @@ class GatherTest(xla_test.XLATestCase):
for axis in 0, 1, 2, 3, -1, -2:
params = self._buildParams(np.random.randn(*shape), dtype)
indices = np.random.randint(shape[axis], size=indices_shape)
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
tf_params = array_ops.placeholder(dtype=dtype)
tf_indices = constant_op.constant(indices, dtype=dtypes.int32)
gather = array_ops.gather(tf_params, tf_indices, axis=axis)
@@ -123,7 +123,7 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(gather_np, gather_value)
def testIndicesWithDifferentDimensions(self):
- with self.test_session():
+ with self.cached_session():
for dtype in self.numeric_tf_types:
params = array_ops.placeholder(dtype=dtype)
indices = array_ops.placeholder(dtype=np.int32)
@@ -137,7 +137,7 @@ class GatherTest(xla_test.XLATestCase):
[[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]}))
def testGatherPrecision(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0],
[0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]])
indices = np.array([1, 2, 3, 1])
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index bf986ade06b11358552ee92df3169f965ce3f534..6fe5a66e0e6717ec738dded9196eef6ba1e2114d 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -54,7 +54,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
inp = GenerateNumpyRandomRGB(shape).astype(nptype)
# Convert to HSV and back, as a batch and individually
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch0 = array_ops.placeholder(nptype, shape=shape)
with self.test_scope():
batch1 = image_ops.rgb_to_hsv(batch0)
@@ -78,7 +78,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
for nptype in self.float_types:
rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255.
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(nptype)
with self.test_scope():
hsv = image_ops.rgb_to_hsv(placeholder)
@@ -97,7 +97,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
for r, g, b in rgb_flat
])
hsv_np = hsv_np.reshape(4, 4, 4, 3)
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(nptype)
with self.test_scope():
hsv_op = image_ops.rgb_to_hsv(placeholder)
@@ -108,7 +108,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
class AdjustContrastTest(xla_test.XLATestCase):
def _testContrast(self, x_np, y_np, contrast_factor):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_np.shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope():
@@ -146,7 +146,7 @@ class AdjustContrastTest(xla_test.XLATestCase):
return y_np
def _adjustContrastTf(self, x_np, contrast_factor):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(np.float32)
with self.test_scope():
y = image_ops.adjust_contrast(x, contrast_factor)
@@ -180,7 +180,7 @@ class AdjustHueTest(xla_test.XLATestCase):
y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope():
@@ -198,7 +198,7 @@ class AdjustHueTest(xla_test.XLATestCase):
y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope():
@@ -216,7 +216,7 @@ class AdjustHueTest(xla_test.XLATestCase):
y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope():
@@ -244,7 +244,7 @@ class AdjustHueTest(xla_test.XLATestCase):
return y_v.reshape(x_np.shape)
def _adjustHueTf(self, x_np, delta_h):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtypes.float32)
with self.test_scope():
y = gen_image_ops.adjust_hue(x, delta_h)
@@ -324,7 +324,7 @@ class AdjustSaturationTest(xla_test.XLATestCase):
y_rgb_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128]
y_np = np.array(y_rgb_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape)
y = self._adjust_saturation(x, saturation_factor)
y_tf = y.eval({x: x_np})
@@ -339,7 +339,7 @@ class AdjustSaturationTest(xla_test.XLATestCase):
y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape)
y = self._adjust_saturation(x, saturation_factor)
y_tf = y.eval({x: x_np})
@@ -378,7 +378,7 @@ class AdjustSaturationTest(xla_test.XLATestCase):
"gb_same",
"rgb_same",
]
- with self.test_session():
+ with self.cached_session():
for x_shape in x_shapes:
for test_style in test_styles:
x_np = np.random.rand(*x_shape) * 255.
@@ -410,13 +410,14 @@ class ResizeBilinearTest(xla_test.XLATestCase):
image_np,
target_shape,
expected=None,
- large_tolerance=False):
+ large_tolerance=False,
+ align_corners=True):
if expected is None:
self.fail("expected must be specified")
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
image = array_ops.placeholder(image_np.dtype)
resized = gen_image_ops.resize_bilinear(
- image, target_shape, align_corners=True)
+ image, target_shape, align_corners=align_corners)
out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]})
if large_tolerance:
self.assertAllClose(
@@ -433,7 +434,7 @@ class ResizeBilinearTest(xla_test.XLATestCase):
self.fail("input_shape must be specified")
if expected is None:
self.fail("expected must be specified")
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
dtype = dtype or np.float32
grads = array_ops.placeholder(np.float32)
resized = gen_image_ops.resize_bilinear_grad(
@@ -579,6 +580,27 @@ class ResizeBilinearTest(xla_test.XLATestCase):
dtype=np.float32)),
large_tolerance=True)
+ def testNonAlignCorners3x2To6x4(self):
+ input_data = [[64, 32], [32, 64], [50, 100]]
+ expected_data = [[64.0, 48.0, 32.0, 32.0], [48.0, 48.0, 48.0, 48.0],
+ [32.0, 48.0, 64.0, 64.0], [41.0, 61.5, 82.0, 82.0],
+ [50.0, 75.0, 100.0, 100.0], [50.0, 75.0, 100.0, 100.0]]
+ for dtype in self.float_types:
+ self._assertForwardOpMatchesExpected(
+ np.array(input_data, dtype=dtype), [6, 4],
+ expected=np.array(expected_data, dtype=np.float32),
+ align_corners=False)
+
+ def testNonAlignCorners6x4To3x2(self):
+ input_data = [[127, 127, 64, 64], [127, 127, 64, 64], [64, 64, 127, 127],
+ [64, 64, 127, 127], [50, 50, 100, 100], [50, 50, 100, 100]]
+ expected_data = [[127, 64], [64, 127], [50, 100]]
+ for dtype in self.float_types:
+ self._assertForwardOpMatchesExpected(
+ np.array(input_data, dtype=dtype), [3, 2],
+ expected=np.array(expected_data, dtype=dtype),
+ align_corners=False)
+
class NonMaxSuppressionTest(xla_test.XLATestCase):
@@ -596,7 +618,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
iou_threshold_np = np.array(0.5, dtype=np.float32)
score_threshold_np = np.array(0.0, dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
@@ -639,7 +661,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
iou_threshold_np = np.array(0.5, dtype=np.float32)
score_threshold_np = np.array(0.0, dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
@@ -686,7 +708,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
iou_threshold_np = np.array(0.5, dtype=np.float32)
score_threshold_np = np.array(0.4, dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py
index 45a04f0cf56e88946b946bedacb25ce6da3121b4..58622114e4f552fb71db9b040a39b57d7da0037c 100644
--- a/tensorflow/compiler/tests/listdiff_op_test.py
+++ b/tensorflow/compiler/tests/listdiff_op_test.py
@@ -33,7 +33,7 @@ class ListDiffTest(xla_test.XLATestCase):
def _testListDiff(self, x, y, out, idx):
for dtype in [dtypes.int32, dtypes.int64]:
for index_dtype in [dtypes.int32, dtypes.int64]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_tensor = ops.convert_to_tensor(x, dtype=dtype)
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
with self.test_scope():
diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py
index 253b45902fba2df64e5234f135b373cd2a0a7e2a..c6ad67993e8bc196a74c9a328df8c9200c92c575 100644
--- a/tensorflow/compiler/tests/lrn_ops_test.py
+++ b/tensorflow/compiler/tests/lrn_ops_test.py
@@ -58,7 +58,7 @@ class LRNTest(xla_test.XLATestCase):
return output
def _RunAndVerify(self, dtype):
- with self.test_session():
+ with self.cached_session():
# random shape
shape = np.random.randint(1, 16, size=4)
# Make depth at least 2 to make it meaningful
@@ -110,7 +110,7 @@ class LRNTest(xla_test.XLATestCase):
alpha = 1.0 * np.random.rand()
beta = 1.0 * np.random.rand()
- with self.test_session():
+ with self.cached_session():
in_image = constant_op.constant(in_image_vals, shape=shape)
out_image = constant_op.constant(out_image_vals, shape=shape)
out_grads = constant_op.constant(out_grads_vals, shape=shape)
diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py
index 31093c65713df55390c3130b8654fdcb10fbc133..265c0b6d1412de7be3a5bf5e79129cb330ceb162 100644
--- a/tensorflow/compiler/tests/lstm_test.py
+++ b/tensorflow/compiler/tests/lstm_test.py
@@ -73,7 +73,7 @@ class LSTMTest(test.TestCase):
def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar,
pad_scalar):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_inputs = 1
num_nodes = 1
@@ -156,7 +156,7 @@ class LSTMTest(test.TestCase):
def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar,
pad_scalar):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_inputs = 1
num_nodes = 1
seq_length = 3
diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py
index 0d9f99f8a6803ecae5f9233518a1768109161ac0..9222db4b7ebf020c8cee1c0af81e05129fb33c4d 100644
--- a/tensorflow/compiler/tests/matrix_band_part_test.py
+++ b/tensorflow/compiler/tests/matrix_band_part_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class MatrixBandPartTest(xla_test.XLATestCase):
def _testMatrixBandPart(self, dtype, shape):
- with self.test_session():
+ with self.cached_session():
batch_shape = shape[:-2]
mat = np.ones(shape).astype(dtype)
batch_mat = np.tile(mat, batch_shape + [1, 1])
diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
index 2bb8a97bdaf5836a05501ab9754433e29ae34675..94cd3eeb3179da9b920ea9f03216d602b042a639 100644
--- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
@@ -54,7 +54,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol):
clean_a = np.tril(a) if lower else np.triu(a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder_a = MakePlaceholder(a)
placeholder_ca = MakePlaceholder(clean_a)
placeholder_b = MakePlaceholder(b)
diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py
index c2592c54cf83d41f0e3bdbc1f4dc9ff276ddb078..f77521a7c49dba39849869ddceb7c0e885147722 100644
--- a/tensorflow/compiler/tests/momentum_test.py
+++ b/tensorflow/compiler/tests/momentum_test.py
@@ -41,7 +41,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -95,7 +95,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
def testNesterovMomentum(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([0.1, 0.2], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([0.3, 0.4], dtype=dtype)
var0_np = np.array([0.1, 0.2], dtype=dtype)
@@ -120,7 +120,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRateAndMomentum(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py
index da08225e9fc0d5a8ec21ee9961c4758fa38628b4..a1c07fce732d3b91a7c0550545a03fdab67644d3 100644
--- a/tensorflow/compiler/tests/nary_ops_test.py
+++ b/tensorflow/compiler/tests/nary_ops_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
class NAryOpsTest(xla_test.XLATestCase):
def _testNAry(self, op, args, expected, equality_fn=None):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
@@ -126,7 +126,7 @@ class NAryOpsTest(xla_test.XLATestCase):
[[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32))
def testOneHot(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32))
op = array_ops.one_hot(indices,
np.int32(4),
@@ -148,7 +148,7 @@ class NAryOpsTest(xla_test.XLATestCase):
self.assertAllEqual(output, expected)
def testSplitV(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
output = session.run(
array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]],
diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py
index 2f9122645d3c5ccabc8130ac30a3f09cf4bc2de7..f985c5d2d96e06fc0117f3935d61b19c9e8562b1 100644
--- a/tensorflow/compiler/tests/nullary_ops_test.py
+++ b/tensorflow/compiler/tests/nullary_ops_test.py
@@ -29,14 +29,14 @@ from tensorflow.python.platform import googletest
class NullaryOpsTest(xla_test.XLATestCase):
def _testNullary(self, op, expected):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
output = op()
result = session.run(output)
self.assertAllClose(result, expected, rtol=1e-3)
def testNoOp(self):
- with self.test_session():
+ with self.cached_session():
with self.test_scope():
output = control_flow_ops.no_op()
# This should not crash.
diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py
index d68d32057a367776d5b70d5ac21d5618297c605d..7635f89249b7b71e5353e0b7cb1cea5c1f7bca1d 100644
--- a/tensorflow/compiler/tests/oom_test.py
+++ b/tensorflow/compiler/tests/oom_test.py
@@ -46,7 +46,7 @@ class OutOfMemoryTest(xla_test.XLATestCase):
def test_loop():
size = int(2e8)
while True:
- with self.test_session():
+ with self.cached_session():
# Force the compiled code to not be constant by feeding in a
# parameter.
p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1])
diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py
index a75d99189b5b673261c9e48f1c5998ea0c575594..77bb839409f0c323ff6ed2c8d6bd105d3003b398 100644
--- a/tensorflow/compiler/tests/placeholder_test.py
+++ b/tensorflow/compiler/tests/placeholder_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import googletest
class PlaceholderTest(xla_test.XLATestCase):
def test_placeholder_with_default_default(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(4.0)
ph = array_ops.placeholder_with_default(v, shape=[])
out = ph * 2
@@ -36,7 +36,7 @@ class PlaceholderTest(xla_test.XLATestCase):
self.assertEqual(8.0, sess.run(out))
def test_placeholder_with_default_fed(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(4.0)
ph = array_ops.placeholder_with_default(v, shape=[])
out = ph * 2
diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py
index 17f860db61aeda98326a6820771d67ee948b6dda..b6cdd38345b9a9f6b03e8799587e3f6ffe07b407 100644
--- a/tensorflow/compiler/tests/pooling_ops_3d_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py
@@ -62,7 +62,7 @@ class Pooling3DTest(xla_test.XLATestCase):
# numbers from 1.
x = np.arange(1.0, total_size + 1, dtype=np.float32)
x = x.reshape(input_sizes)
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
inputs = array_ops.placeholder(dtypes.float32)
t = pool_func(
inputs,
@@ -210,7 +210,7 @@ class Pooling3DTest(xla_test.XLATestCase):
strides = [1] + strides + [1]
total_size = np.prod(input_sizes)
x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Use the forward pool function to compute some corresponding outputs
# (needed for the CPU device, and we need the shape in both cases).
with ops.device("CPU"):
diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py
index 9fc94752ea660f7fb8b2c792180f01485ad04419..d03bd4fdbb7694bc36291faf9b845ec48e26a386 100644
--- a/tensorflow/compiler/tests/pooling_ops_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_test.py
@@ -89,7 +89,7 @@ class PoolingTest(xla_test.XLATestCase):
# numbers from 1.
x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32)
x = x.reshape(input_sizes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
inputs = array_ops.placeholder(dtypes.float32)
t = inputs
@@ -324,7 +324,7 @@ class PoolGradTest(xla_test.XLATestCase):
# TODO(b/74222344): Fix nan handling for max pool grad.
# x[np.random.choice(total_size)] = np.nan
x = x.reshape(input_sizes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Use the forward pool function to compute some corresponding outputs
# (needed for the CPU device, and we need the shape in both cases).
with ops.device(self.CPU_DEVICE):
diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py
index 5fa7706d7294f2cffb7d24a56851be02d759335a..86536da7fed0e2309beb32fee9c7c605491592ed 100644
--- a/tensorflow/compiler/tests/powersign_test.py
+++ b/tensorflow/compiler/tests/powersign_test.py
@@ -64,7 +64,7 @@ class PowerSignTest(xla_test.XLATestCase):
base=math.e,
beta=0.9):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
# Initialize variables for numpy implementation.
m0, m1 = 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype)
diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py
index cde87db63dbfd7c8d823c6fd0e41eee8b23735bb..c41b4171e26af4f7ad0237d7407a5b3691299595 100644
--- a/tensorflow/compiler/tests/proximal_adagrad_test.py
+++ b/tensorflow/compiler/tests/proximal_adagrad_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_adagrad
class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
def testResourceProximalAdagradwithoutRegularization(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -60,7 +60,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
self.assertEqual(2, len(opt_vars))
def testProximalAdagradwithoutRegularization2(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -84,7 +84,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval())
def testProximalAdagradWithL1(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -108,7 +108,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval())
def testProximalAdagradWithL1_L2(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -151,7 +151,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
return var0.eval(), var1.eval()
def testEquivAdagradwithoutRegularization(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val0, val1 = self.applyOptimizer(
proximal_adagrad.ProximalAdagradOptimizer(
3.0,
@@ -159,7 +159,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
l1_regularization_strength=0.0,
l2_regularization_strength=0.0))
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val2, val3 = self.applyOptimizer(
adagrad.AdagradOptimizer(
3.0, initial_accumulator_value=0.1))
diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py
index 11eb76871133eba8fcd24621afb03e16614fb005..3d808e6b8a71ef9fa60b671d07bfd907e9f58efc 100644
--- a/tensorflow/compiler/tests/proximal_gradient_descent_test.py
+++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_gradient_descent
class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
def testResourceProximalGradientDescentwithoutRegularization(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -53,7 +53,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([-0.09, -0.18]), var1.eval())
def testProximalGradientDescentwithoutRegularization2(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -75,7 +75,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([3.91, 2.82]), var1.eval())
def testProximalGradientDescentWithL1(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -97,7 +97,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([3.67, 2.37]), var1.eval())
def testProximalGradientDescentWithL1_L2(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -137,14 +137,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
return var0.eval(), var1.eval()
def testEquivGradientDescentwithoutRegularization(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val0, val1 = self.applyOptimizer(
proximal_gradient_descent.ProximalGradientDescentOptimizer(
3.0,
l1_regularization_strength=0.0,
l2_regularization_strength=0.0))
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0))
diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
index 1b969ee2b3886fca6ec9951d1621ca5af6a673d8..3a268978bfd72d08a7d3a7cc61a116dac543cda5 100644
--- a/tensorflow/compiler/tests/qr_op_test.py
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -71,7 +71,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
x_np = np.random.uniform(
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_tf = array_ops.placeholder(dtype)
with self.test_scope():
q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 8c4e16e4e075726d741f6ff8cdfb6b1aad6cd33e..6e183441179ebf2e8c063b333f9328d6fa86cc88 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -39,7 +39,7 @@ class RandomOpsTest(xla_test.XLATestCase):
def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = rng(dtype)
@@ -79,7 +79,7 @@ class RandomOpsTest(xla_test.XLATestCase):
if (self.device in ["XLA_GPU", "XLA_CPU"
]) and (dtype in [dtypes.bfloat16, dtypes.half]):
continue
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = random_ops.random_uniform(
shape=[1000], dtype=dtype, minval=-2, maxval=33)
@@ -99,7 +99,7 @@ class RandomOpsTest(xla_test.XLATestCase):
count = 10000000
# TODO(b/34339814): implement inverse erf support for non-F32 types.
for dtype in [dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
y = sess.run(x)
@@ -147,7 +147,7 @@ class RandomOpsTest(xla_test.XLATestCase):
# TODO(b/26783907): this test requires the CPU backend to implement sort.
if self.device in ["XLA_CPU"]:
return
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = math_ops.range(1 << 16)
shuffle = random_ops.random_shuffle(x)
@@ -158,7 +158,7 @@ class RandomOpsTest(xla_test.XLATestCase):
self.assertAllEqual(set(result), set(expected))
def testShuffle2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = array_ops.diag(math_ops.range(20))
shuffle = random_ops.random_shuffle(x)
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index cea2ec816f85e88b11e6e80c91c14fca9015f45c..5ae5b1bc1df76e6d0267a9a9ac18e7bc4725ec7b 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import functools
import itertools
+from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
@@ -30,22 +31,24 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class ReduceOpsTest(xla_test.XLATestCase):
-
+@parameterized.named_parameters(('32_bit_index', dtypes.int32),
+ ('64_bit_index', dtypes.int64))
+class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def _testReduction(self,
tf_reduce_fn,
np_reduce_fn,
dtype,
test_inputs,
+ index_dtype,
rtol=1e-4,
atol=1e-4):
"""Tests that the output of 'tf_reduce_fn' matches numpy's output."""
for test_input in test_inputs:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
a = array_ops.placeholder(dtype)
- index = array_ops.placeholder(dtypes.int32)
+ index = array_ops.placeholder(index_dtype)
out = tf_reduce_fn(a, index)
result = sess.run(out, {a: test_input, index: [0]})
self.assertAllClose(
@@ -89,22 +92,23 @@ class ReduceOpsTest(xla_test.XLATestCase):
np.array([[False, True, False], [True, True, False]]),
]
- def testReduceSumF32(self):
- self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA)
+ def testReduceSumF32(self, index_dtype):
+ self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA,
+ index_dtype)
- def testReduceSumC64(self):
+ def testReduceSumC64(self, index_dtype):
self._testReduction(math_ops.reduce_sum, np.sum, np.complex64,
- self.COMPLEX_DATA)
+ self.COMPLEX_DATA, index_dtype)
- def testReduceProdF32(self):
+ def testReduceProdF32(self, index_dtype):
self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
- self.REAL_DATA)
+ self.REAL_DATA, index_dtype)
- def testReduceProdC64(self):
+ def testReduceProdC64(self, index_dtype):
self._testReduction(math_ops.reduce_prod, np.prod, np.complex64,
- self.COMPLEX_DATA)
+ self.COMPLEX_DATA, index_dtype)
- def testReduceMin(self):
+ def testReduceMin(self, index_dtype):
def reference_min(dtype, inp, axis):
"""Wrapper around np.amin that returns +infinity for an empty input."""
@@ -119,9 +123,9 @@ class ReduceOpsTest(xla_test.XLATestCase):
[np.float32, np.int32, np.int64]):
self._testReduction(math_ops.reduce_min,
functools.partial(reference_min, dtype), dtype,
- self.REAL_DATA)
+ self.REAL_DATA, index_dtype)
- def testReduceMax(self):
+ def testReduceMax(self, index_dtype):
def reference_max(dtype, inp, axis):
"""Wrapper around np.amax that returns -infinity for an empty input."""
@@ -137,23 +141,25 @@ class ReduceOpsTest(xla_test.XLATestCase):
[np.float32, np.int32, np.int64]):
self._testReduction(math_ops.reduce_max,
functools.partial(reference_max, dtype), dtype,
- self.REAL_DATA)
+ self.REAL_DATA, index_dtype)
- def testReduceMeanF32(self):
+ def testReduceMeanF32(self, index_dtype):
# TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
# reducing across zero inputs.
self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
- self.NONEMPTY_REAL_DATA)
+ self.NONEMPTY_REAL_DATA, index_dtype)
- def testReduceMeanC64(self):
+ def testReduceMeanC64(self, index_dtype):
self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,
- self.NONEMPTY_COMPLEX_DATA)
+ self.NONEMPTY_COMPLEX_DATA, index_dtype)
- def testReduceAll(self):
- self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA)
+ def testReduceAll(self, index_dtype):
+ self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA,
+ index_dtype)
- def testReduceAny(self):
- self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
+ def testReduceAny(self, index_dtype):
+ self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA,
+ index_dtype)
class ReduceOpPrecisionTest(xla_test.XLATestCase):
@@ -178,7 +184,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase):
"""
for test_input in test_inputs:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
a = array_ops.placeholder(dtype)
index = array_ops.placeholder(dtypes.int32)
diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py
index c69b6837b0f88ced844faf3713a29a1c14c8790d..ff20ea3f4287b4666684501fa4920435a77b4183 100644
--- a/tensorflow/compiler/tests/reduce_window_test.py
+++ b/tensorflow/compiler/tests/reduce_window_test.py
@@ -32,7 +32,7 @@ class ReduceWindowTest(xla_test.XLATestCase):
"""Test cases for xla.reduce_window."""
def _reduce_window(self, operand, init, reducer, **kwargs):
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(operand.dtype)
with self.test_scope():
output = xla.reduce_window(placeholder, init, reducer, **kwargs)
diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py
index 32ab5d08f0b925ee6b7b641ddba6b950149a6d20..392290fd92d0c7c928581422433892147374b2dd 100644
--- a/tensorflow/compiler/tests/reverse_ops_test.py
+++ b/tensorflow/compiler/tests/reverse_ops_test.py
@@ -51,7 +51,7 @@ class ReverseOpsTest(xla_test.XLATestCase):
def _AssertReverseEqual(self, revdims, shape):
np.random.seed(120)
pval = np.random.randint(0, 100, size=shape).astype(float)
- with self.test_session():
+ with self.cached_session():
with self.test_scope():
p = array_ops.placeholder(dtypes.int32, shape=shape)
axis = constant_op.constant(
diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py
index ccfa63001653537c4d1b7140e3d745c126f9034b..60c2337743b44e9bad61c4d65280eb2b1a1ad9ea 100644
--- a/tensorflow/compiler/tests/reverse_sequence_op_test.py
+++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py
@@ -35,7 +35,7 @@ class ReverseSequenceTest(xla_test.XLATestCase):
seq_lengths,
truth,
expected_err_re=None):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes.as_dtype(x.dtype))
lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype))
with self.test_scope():
diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py
index ff8bbac911abe73f946464663984ff1626302882..8840a1329a907bddc6ef1cb6dd1c2a6d234def5c 100644
--- a/tensorflow/compiler/tests/rmsprop_test.py
+++ b/tensorflow/compiler/tests/rmsprop_test.py
@@ -55,7 +55,7 @@ class RmspropTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
for centered in [False, True]:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
# Initialize variables for numpy implementation.
var0_np = np.array([1.0, 2.0], dtype=dtype)
grads0_np = np.array([0.1, 0.1], dtype=dtype)
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
index 4292352e76ebcef7dbf41df7b857d2604a468117..897db384b7e8067b0460b5f344201f101a4d8479 100644
--- a/tensorflow/compiler/tests/scan_ops_test.py
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -78,7 +78,7 @@ class CumsumTest(xla_test.XLATestCase):
def _compare(self, x, axis, exclusive, reverse):
np_out = handle_options(np.cumsum, x, axis, exclusive, reverse)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
p = array_ops.placeholder(x.dtype)
tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval(
feed_dict={p: x})
@@ -100,7 +100,7 @@ class CumsumTest(xla_test.XLATestCase):
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)
for axis_dtype in self.axis_dtypes():
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
p = array_ops.placeholder(x.dtype)
axis = constant_op.constant(0, axis_dtype)
math_ops.cumsum(p, axis).eval(feed_dict={p: x})
@@ -131,7 +131,7 @@ class CumsumTest(xla_test.XLATestCase):
def testInvalidAxis(self):
x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
input_tensor = ops.convert_to_tensor(x)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
@@ -156,7 +156,7 @@ class CumprodTest(xla_test.XLATestCase):
def _compare(self, x, axis, exclusive, reverse):
np_out = handle_options(np.cumprod, x, axis, exclusive, reverse)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
p = array_ops.placeholder(x.dtype)
prod = math_ops.cumprod(p, axis, exclusive, reverse)
tf_out = prod.eval(feed_dict={p: x})
@@ -178,7 +178,7 @@ class CumprodTest(xla_test.XLATestCase):
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)
for axis_dtype in self.axis_dtypes():
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
p = array_ops.placeholder(x.dtype)
axis = constant_op.constant(0, axis_dtype)
math_ops.cumprod(x, axis).eval(feed_dict={p: x})
@@ -209,7 +209,7 @@ class CumprodTest(xla_test.XLATestCase):
def testInvalidAxis(self):
x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
input_tensor = ops.convert_to_tensor(x)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py
index f606f88545d0b6f0b52cee9b93083a6bd91169bc..693f8513bc54e30060a2e963abd504768535a50a 100644
--- a/tensorflow/compiler/tests/scatter_nd_op_test.py
+++ b/tensorflow/compiler/tests/scatter_nd_op_test.py
@@ -119,7 +119,7 @@ class ScatterNdTest(xla_test.XLATestCase):
self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)
def _runScatterNd(self, indices, updates, shape):
- with self.test_session():
+ with self.cached_session():
updates_placeholder = array_ops.placeholder(updates.dtype)
indices_placeholder = array_ops.placeholder(indices.dtype)
with self.test_scope():
diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py
index 772c20fd424577c3e06eeae409f424b77b52aa8a..287bb0d84e24de3bdcde3aa4c61acee00626e88f 100644
--- a/tensorflow/compiler/tests/segment_reduction_ops_test.py
+++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py
@@ -32,7 +32,7 @@ class SegmentReductionOpsTest(xla_test.XLATestCase):
"""Test cases for segment reduction ops."""
def _segmentReduction(self, op, data, indices, num_segments):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
d = array_ops.placeholder(data.dtype, shape=data.shape)
if isinstance(indices, int):
i = array_ops.placeholder(np.int32, shape=[])
diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py
index 6c4890565d2083a9493abc59bd563c4dd9fdb186..2c611a959e1d71c53e44bc92c31258153d01507d 100644
--- a/tensorflow/compiler/tests/slice_ops_test.py
+++ b/tensorflow/compiler/tests/slice_ops_test.py
@@ -29,7 +29,7 @@ class SliceTest(xla_test.XLATestCase):
def test1D(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[10])
with self.test_scope():
o = array_ops.slice(i, [2], [4])
@@ -40,9 +40,22 @@ class SliceTest(xla_test.XLATestCase):
self.assertAllEqual([2, 3, 4, 5], result)
+ def testZeroSlice(self):
+ for dtype in self.numeric_types:
+ with self.cached_session():
+ i = array_ops.placeholder(dtype, shape=[2])
+ with self.test_scope():
+ o = array_ops.slice(i, [0], [0])
+ params = {
+ i: [0, 1],
+ }
+ result = o.eval(feed_dict=params)
+
+ self.assertAllEqual([], result)
+
def test3D(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10])
with self.test_scope():
o = array_ops.slice(i, [1, 2, 2], [1, 1, 4])
@@ -64,7 +77,7 @@ class SliceTest(xla_test.XLATestCase):
def test3DWithDynamicBegin(self):
"""Tests a slice where the start offset is not known at compile time."""
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10])
begin = array_ops.placeholder(dtypes.int32, shape=[3])
with self.test_scope():
@@ -88,7 +101,7 @@ class SliceTest(xla_test.XLATestCase):
def test3DWithDynamicBeginAndNegativeSize(self):
"""Tests a slice where `begin` is fed dynamically and `size` contains -1."""
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10])
begin = array_ops.placeholder(dtypes.int32, shape=[3])
with self.test_scope():
@@ -114,7 +127,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test1D(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[10])
with self.test_scope():
o = array_ops.strided_slice(i, [2], [6], [2])
@@ -127,7 +140,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test1DNegativeStride(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[10])
with self.test_scope():
o = array_ops.strided_slice(i, [6], [2], [-2])
@@ -140,7 +153,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test2DDegenerate(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[2, 3])
with self.test_scope():
o = array_ops.strided_slice(i, [-1, 0], [0, 3])
@@ -154,7 +167,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test2DDegenerateNegativeStride(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[2, 3])
with self.test_scope():
o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1])
@@ -168,7 +181,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test3D(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10])
with self.test_scope():
o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2])
@@ -189,7 +202,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test3DNegativeStride(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[3, 4, 10])
with self.test_scope():
o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2])
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 7ff01be3cb4848d6bb85b8ab96b3ee1db6889791..51c04b5c4796474700a92a8b23a1cbdf533fcbb4 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class XlaSortOpTest(xla_test.XLATestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
@@ -131,7 +131,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
if bfloat16 not in self.numeric_types:
return
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.bfloat16)
with self.test_scope():
topk = nn_ops.top_k(p, k=4)
@@ -153,7 +153,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
if bfloat16 not in self.numeric_types:
return
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.bfloat16)
with self.test_scope():
topk = nn_ops.top_k(p, k=6)
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
index c685bc548f9f6f8f7723c6f94dfd45f5420b4a67..33b84cec7188c85a3bacb20a6df29c73adbd107c 100644
--- a/tensorflow/compiler/tests/spacetobatch_op_test.py
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -72,7 +72,7 @@ class SpaceToBatchTest(xla_test.XLATestCase):
"""Tests input-output pairs for the SpaceToBatch and BatchToSpace ops."""
def _testPad(self, inputs, paddings, block_size, outputs):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self.float_types:
# outputs = space_to_batch(inputs)
placeholder = array_ops.placeholder(dtype)
@@ -155,7 +155,7 @@ class SpaceToBatchNDTest(xla_test.XLATestCase):
def _testPad(self, inputs, block_shape, paddings, outputs):
block_shape = np.array(block_shape)
paddings = np.array(paddings).reshape((len(block_shape), 2))
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self.float_types:
# TODO(b/68813416): Skip bfloat16's as the input type for direct is
# float32 and results in a mismatch, while making testDirect provide the
diff --git a/tensorflow/compiler/tests/sparse_to_dense_op_test.py b/tensorflow/compiler/tests/sparse_to_dense_op_test.py
index 3db8101c4bfbb1b53c7318a36519612984d6f179..07afd1ab3fb78d5accc52ee2382af0b9fb8079d3 100644
--- a/tensorflow/compiler/tests/sparse_to_dense_op_test.py
+++ b/tensorflow/compiler/tests/sparse_to_dense_op_test.py
@@ -45,32 +45,32 @@ def _SparseToDense(sparse_indices,
class SparseToDenseTest(xla_test.XLATestCase):
def testInt(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], 1, 0)
np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans)
def testFloat(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0)
np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32)
self.assertAllClose(np_ans, tf_ans)
def testSetValue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1)
np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans)
def testSetSingleValue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], 1, -1)
np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans)
def test2d(self):
# pylint: disable=bad-whitespace
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1)
np_ans = np.array([[-1, -1, -1, -1],
[-1, -1, -1, 1],
@@ -78,12 +78,12 @@ class SparseToDenseTest(xla_test.XLATestCase):
self.assertAllClose(np_ans, tf_ans)
def testZeroDefault(self):
- with self.test_session():
+ with self.cached_session():
x = sparse_ops.sparse_to_dense(2, [4], 7).eval()
self.assertAllEqual(x, [0, 0, 7, 0])
def test3d(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1)
np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1
np_ans[1, 3, 0] = 1
@@ -91,25 +91,25 @@ class SparseToDenseTest(xla_test.XLATestCase):
self.assertAllClose(np_ans, tf_ans)
def testBadShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"):
_SparseToDense([1, 3], [[5], [3]], 1, -1)
def testBadValue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[2,1\], "
r"should be \[\] or \[2\]"):
_SparseToDense([1, 3], [5], [[5], [3]], -1)
def testBadNumValues(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
_SparseToDense([1, 3], [5], [1, 2, 3], -1)
def testBadDefault(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
with self.assertRaisesOpError("default_value should be a scalar"):
_SparseToDense([1, 3], [5], [1, 2], [0])
diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py
index b7dd787feff2b22a9cfb5d43a4ba6ceb6eb0b301..720595a159eea997be2246c4c7dad49612b257eb 100644
--- a/tensorflow/compiler/tests/stack_ops_test.py
+++ b/tensorflow/compiler/tests/stack_ops_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.platform import test
class StackOpTest(xla_test.XLATestCase):
def testStackPushPop(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
size = array_ops.placeholder(dtypes.int32)
v = array_ops.placeholder(dtypes.float32)
h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo")
@@ -41,7 +41,7 @@ class StackOpTest(xla_test.XLATestCase):
self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]}))
def testStackPushPopSwap(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
a = np.arange(2000)
x = array_ops.placeholder(dtypes.float32)
h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo")
@@ -51,7 +51,7 @@ class StackOpTest(xla_test.XLATestCase):
self.assertAllClose(a, c1.eval({x: a}))
def testMultiStack(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
v = array_ops.placeholder(dtypes.float32)
h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo")
c1 = gen_data_flow_ops.stack_push_v2(h1, v)
@@ -66,7 +66,7 @@ class StackOpTest(xla_test.XLATestCase):
def testSameNameStacks(self):
"""Different stacks with the same name do not interfere."""
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
v1 = array_ops.placeholder(dtypes.float32)
v2 = array_ops.placeholder(dtypes.float32)
h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo")
@@ -84,14 +84,14 @@ class StackOpTest(xla_test.XLATestCase):
self.assertAllClose(out2, 5.0)
def testCloseStack(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
size = array_ops.placeholder(dtypes.int32)
h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo")
c1 = gen_data_flow_ops.stack_close_v2(h)
sess.run(c1, {size: 5})
def testPushCloseStack(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
v = array_ops.placeholder(dtypes.float32)
h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo")
c = gen_data_flow_ops.stack_push_v2(h, v)
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index d162675ef840131485128414b4a29e3cd89c8761..1bea7d9355e40c5a71f848dabc0fa7fa760429d2 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -38,7 +38,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
seeds = [(x, y) for x in range(5) for y in range(5)] * 3
for stateless_op in [
@@ -55,7 +55,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertEqual(s0 == s1, np.all(v0 == v1))
def testRandomUniformIsInRange(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
x = stateless.stateless_random_uniform(
@@ -74,7 +74,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def testDistributionOfStatelessRandomUniform(self):
"""Use Pearson's Chi-squared test to test for uniformity."""
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 1000
@@ -88,7 +88,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertTrue(self._chi_squared(y, 10) < 16.92)
def testRandomNormalIsFinite(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
x = stateless.stateless_random_uniform(
@@ -111,7 +111,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def testDistributionOfStatelessRandomNormal(self):
"""Use Anderson-Darling test to test distribution appears normal."""
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 1000
@@ -126,7 +126,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def testTruncatedNormalIsInRange(self):
# TODO(b/34339814): implement inverse erf support for non-F32 types.
for dtype in [dtypes.float32]:
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 10000000
x = stateless.stateless_truncated_normal(
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index f332aa2e9b97e13654cf9b10588c18fed32f7ad4..78244d0b366d9128a4c59f786e4c5ac12e743b75 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -44,7 +44,7 @@ def _make_converter(dtype):
class TensorArrayTest(xla_test.XLATestCase):
def testTensorArrayWriteRead(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -66,7 +66,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([], flow_val.shape)
def _testTensorArrayWritePack(self, tf_dtype):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
@@ -86,7 +86,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayWritePack(dtype)
def testEmptyTensorArrayPack(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
@@ -100,7 +100,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([3, 0, 1], c0.eval().shape)
def _testTensorArrayWriteConcat(self, tf_dtype):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
@@ -121,7 +121,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayWriteConcat(dtype)
def _testTensorArrayUnpackRead(self, tf_dtype):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
@@ -176,7 +176,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayUnpackReadMaybeLegacy()
def _testTensorArraySplitRead(self, tf_dtype):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
@@ -228,7 +228,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArraySplitRead(dtype)
def testTensorGradArrayWriteRead(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -261,7 +261,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([[-2.0]], g_d2)
def testTensorGradArrayDynamicWriteRead(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -300,7 +300,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(3, g_vs)
def testTensorGradAccessTwiceReceiveSameObject(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3,
element_shape=[1, 2])
@@ -317,7 +317,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([[4.0, 5.0]], d_r1_0)
def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
@@ -331,7 +331,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# the first type, but try to read the other type.
if len(self.float_types) > 1:
dtype1, dtype2 = list(self.float_types)[:2]
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtype1, tensor_array_name="foo", size=3)
@@ -347,7 +347,7 @@ class TensorArrayTest(xla_test.XLATestCase):
w0.read(1)
def testTensorArraySplitIncompatibleShapesFails(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -379,7 +379,7 @@ class TensorArrayTest(xla_test.XLATestCase):
ta.split([1.0], [1]).flow.eval()
def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False)
@@ -410,7 +410,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
def testMultiTensorArray(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
h1 = tensor_array_ops.TensorArray(
size=1, dtype=dtypes.float32, tensor_array_name="foo")
w1 = h1.write(0, 4.0)
@@ -425,7 +425,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllClose(9.0, r.eval())
def _testTensorArrayGradientWriteReadType(self, dtype):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.as_dtype(dtype),
tensor_array_name="foo",
@@ -478,7 +478,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayGradientWriteReadType(dtype)
def _testTensorArrayGradientWritePackConcatAndRead(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -513,7 +513,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayGradientWritePackConcatAndRead()
def testTensorArrayReadTwice(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
ta_readtwice = tensor_array_ops.TensorArray(
@@ -529,7 +529,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([1.0, -1.0], r1_readtwice.eval())
def _testTensorArrayGradientUnpackRead(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -557,7 +557,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayGradientUnpackRead()
def testTensorArrayGradientSplitConcat(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=2)
@@ -581,21 +581,21 @@ class TensorArrayTest(xla_test.XLATestCase):
grad_vals[0])
def testCloseTensorArray(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
c1 = ta.close()
session.run(c1)
def testSizeTensorArray(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
s = ta.size()
self.assertAllEqual(3, s.eval())
def testWriteCloseTensorArray(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -608,7 +608,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# TODO(phawkins): implement while loops.
# def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
# np_dtype = dtype.as_numpy_dtype
- # with self.test_session() as session, self.test_scope():
+ # with self.cached_session() as session, self.test_scope():
# v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5))
# var = variables.Variable(np.arange(100, 105, dtype=np_dtype))
# state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype))
@@ -692,7 +692,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# dynamic_size=True, dtype=dtypes.float32)
# def testGradSerialTwoLoops(self):
- # with self.test_session(), self.test_scope():
+ # with self.cached_session(), self.test_scope():
# num_steps = 100
# acc = tensor_array_ops.TensorArray(
# dtype=dtypes.float32,
@@ -725,7 +725,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# self.assertAllClose(31.0, grad.eval())
def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
a = array_ops.identity(
np.arange(
3 * 5, dtype=np.float32).reshape(3, 5) + 1)
@@ -757,7 +757,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(joint_grad_b_t, g0)
def testWriteShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
c0 = constant_op.constant([4.0, 5.0])
@@ -781,7 +781,7 @@ class TensorArrayTest(xla_test.XLATestCase):
w0.write(0, c2)
def testPartlyUnknownShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=6)
@@ -821,7 +821,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list())
def _testUnpackShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -846,7 +846,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testUnpackShape()
def testSplitShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -867,7 +867,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape())
def testWriteUnknownShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -879,7 +879,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape())
def _testGradientWhenNotAllComponentsRead(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
x = constant_op.constant([2.0, 3.0])
w = ta.unstack(x)
@@ -893,7 +893,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testGradientWhenNotAllComponentsRead()
def _testTensorArrayEvalEmpty(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=0, infer_shape=False)
with self.assertRaisesOpError(
@@ -906,7 +906,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayEvalEmpty()
def _testTensorArrayEvalEmptyWithDefault(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=0, infer_shape=True)
self.assertEqual(0, ta.size().eval())
@@ -921,7 +921,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayEvalEmptyWithDefault()
def testTensorArrayScatterReadAndGradients(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -946,7 +946,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
def testTensorArrayWriteGatherAndGradients(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -974,7 +974,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(expected_grad, grad_vals[0])
def testTensorArrayIdentity(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2,
infer_shape=False)
ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4,
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index effa5a59fee7dda543b2c409dfaa27a972a55808..55a992195f2df72677b77757ae86171fa662439f 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.platform import googletest
class TernaryOpsTest(xla_test.XLATestCase):
def _testTernary(self, op, a, b, c, expected):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 124cf9da813861fb3774e3bb29ad947af1598059..5b0e57f83ff4b5a8d1891bef0675074bd67addce 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -65,7 +65,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
rtol: relative tolerance for equality test.
atol: absolute tolerance for equality test.
"""
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
pinp = array_ops.placeholder(
dtypes.as_dtype(inp.dtype), inp.shape, name="a")
@@ -202,7 +202,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
# Disable float16 testing for now
if dtype != np.float16:
x = np.arange(-10, 10, 1).astype(dtype)
- with self.test_session() as session:
+ with self.cached_session() as session:
erf_x = session.run(math_ops.erf(x))
erfc_x = session.run(math_ops.erfc(x))
diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py
index b637cf31cfc303ebe84ce8307ef4ad8b0b5cd720..4ee144beb7f3243be069d59ee4a613484fe183b3 100644
--- a/tensorflow/compiler/tests/while_test.py
+++ b/tensorflow/compiler/tests/while_test.py
@@ -43,7 +43,7 @@ class WhileTest(xla_test.XLATestCase):
def loop_cond(step):
return step < 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_index = array_ops.placeholder(dtypes.int32, [])
with self.test_scope():
loop_outputs = xla.while_loop([init_index], loop_cond, loop_body)
@@ -65,7 +65,7 @@ class WhileTest(xla_test.XLATestCase):
del rsum
return step < 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_index = array_ops.placeholder(dtypes.int32, [])
init_sum = array_ops.placeholder(dtypes.float32, [])
with self.test_scope():
@@ -91,7 +91,7 @@ class WhileTest(xla_test.XLATestCase):
del rsum
return step < 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_index = array_ops.placeholder(dtypes.int32, [])
init_sum = array_ops.placeholder(dtypes.complex64, [])
with self.test_scope():
@@ -117,7 +117,7 @@ class WhileTest(xla_test.XLATestCase):
del x
return step < 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_index = array_ops.placeholder(dtypes.int32, [])
with self.test_scope():
loop_outputs = xla.while_loop([init_index, 42], loop_cond, loop_body)
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
index 85084bb1240cf05f6eabfbea772df113cabe613c..28d61fb07dcb665fa0dbe3f3e566e291e24fa662 100644
--- a/tensorflow/compiler/tests/xla_device_test.py
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -37,7 +37,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
[16384, 1], [1, 16384], [1, 20000, 1, 1]]
for dtype in self.numeric_types:
for shape in shapes:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device("CPU"):
x = array_ops.placeholder(dtype, shape)
with self.test_scope():
@@ -58,7 +58,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
])
shape = (10, 10)
for unsupported_dtype in test_types - self.all_types:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device("CPU"):
x = array_ops.placeholder(unsupported_dtype, shape)
with self.test_scope():
@@ -78,7 +78,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
pass
def testControlTrigger(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = gen_control_flow_ops.control_trigger()
sess.run(x)
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2f026df6c0c28fcbceaa0493871bc12c2d23b1f
--- /dev/null
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -0,0 +1,301 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for XLA op wrappers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.compiler.tf2xla.python import xla
+from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import googletest
+
+
+class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
+
+ def _assertOpOutputMatchesExpected(self, op, args, expected,
+ equality_fn=None):
+ with self.test_session() as session:
+ with self.test_scope():
+ placeholders = [
+ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
+ for arg in args
+ ]
+ feeds = {placeholders[i]: args[i] for i in range(0, len(args))}
+ output = op(*placeholders)
+ result = session.run(output, feeds)
+ if not equality_fn:
+ equality_fn = self.assertAllClose
+ equality_fn(result, expected, rtol=1e-3)
+
+ def testAdd(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.add,
+ args=(np.array([1, 2, 3], dtype=dtype),
+ np.array([4, 5, 6], dtype=dtype)),
+ expected=np.array([5, 7, 9], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
+ lambda x, y: xla.add(x, y, broadcast_dims=(0,)),
+ args=(np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([7, 11], dtype=dtype)),
+ expected=np.array([[8, 9], [14, 15]], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
+ lambda x, y: xla.add(x, y, broadcast_dims=(1,)),
+ args=(np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([7, 11], dtype=dtype)),
+ expected=np.array([[8, 13], [10, 15]], dtype=dtype))
+
+ def testBroadcast(self):
+ for dtype in self.numeric_types:
+ v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
+ self._assertOpOutputMatchesExpected(
+ lambda x: xla.broadcast(x, (7, 42)),
+ args=(v,),
+ expected=np.tile(v, (7, 42, 1, 1)))
+
+ def testShiftRightLogical(self):
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_logical,
+ args=(np.array([-1, 16], dtype=np.int32), np.int32(4)),
+ expected=np.array([0x0FFFFFFF, 1], dtype=np.int32))
+
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_logical,
+ args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
+ expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32))
+
+ def testShiftRightArithmetic(self):
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_arithmetic,
+ args=(np.array([-1, 16], dtype=np.int32), np.int32(4)),
+ expected=np.array([-1, 1], dtype=np.int32))
+
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_arithmetic,
+ args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
+ expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32))
+
+ PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT,
+ xla_data_pb2.PrecisionConfigProto.HIGH,
+ xla_data_pb2.PrecisionConfigProto.HIGHEST)
+
+ @parameterized.parameters(*PRECISION_VALUES)
+ def testConv(self, precision):
+ for dtype in set(self.float_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ def conv_1d_fn(lhs, rhs):
+ dnums = xla_data_pb2.ConvolutionDimensionNumbers()
+ num_spatial_dims = 1
+ dnums.input_batch_dimension = 0
+ dnums.input_feature_dimension = 1
+ dnums.output_batch_dimension = 0
+ dnums.output_feature_dimension = 1
+ dnums.kernel_output_feature_dimension = 0
+ dnums.kernel_input_feature_dimension = 1
+ dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ precision_config = None
+ if precision:
+ precision_config = xla_data_pb2.PrecisionConfigProto()
+ precision_config.operand_precision.extend([precision, precision])
+ return xla.conv(
+ lhs,
+ rhs,
+ window_strides=(1,),
+ padding=((2, 1),),
+ lhs_dilation=(1,),
+ rhs_dilation=(2,),
+ dimension_numbers=dnums)
+
+ self._assertOpOutputMatchesExpected(
+ conv_1d_fn,
+ args=(
+ np.array([[[3, 4, 5, 6]]], dtype=dtype),
+ np.array([[[-2, -3]]], dtype=dtype),
+ ),
+ expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype))
+
+ @parameterized.parameters(*PRECISION_VALUES)
+ def testDotGeneral(self, precision):
+ for dtype in self.float_types:
+
+ def dot_fn(lhs, rhs):
+ dnums = xla_data_pb2.DotDimensionNumbers()
+ dnums.lhs_contracting_dimensions.append(2)
+ dnums.rhs_contracting_dimensions.append(1)
+ dnums.lhs_batch_dimensions.append(0)
+ dnums.rhs_batch_dimensions.append(0)
+ precision_config = None
+ if precision:
+ precision_config = xla_data_pb2.PrecisionConfigProto()
+ precision_config.operand_precision.extend([precision, precision])
+ return xla.dot_general(
+ lhs,
+ rhs,
+ dimension_numbers=dnums,
+ precision_config=precision_config)
+
+ lhs = np.array(
+ [
+ [[1, 2], [3, 4]],
+ [[5, 6], [7, 8]],
+ ], dtype=dtype)
+ rhs = np.array(
+ [
+ [[1, 2, 3], [4, 5, 6]],
+ [[7, 8, 9], [10, 11, 12]],
+ ], dtype=dtype)
+ self._assertOpOutputMatchesExpected(
+ dot_fn,
+ args=(lhs, rhs),
+ expected=np.array(
+ [
+ [[9, 12, 15], [19, 26, 33]],
+ [[95, 106, 117], [129, 144, 159]],
+ ],
+ dtype=dtype))
+
+ def testNeg(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.neg,
+ args=(np.array([1, 2, 3], dtype=dtype),),
+ expected=np.array([-1, -2, -3], dtype=dtype))
+
+ def testPad(self):
+ for dtype in self.numeric_types:
+
+ def pad_fn(x):
+ return xla.pad(
+ x,
+ padding_value=7,
+ padding_low=[2, 1],
+ padding_high=[1, 2],
+ padding_interior=[1, 0])
+
+ self._assertOpOutputMatchesExpected(
+ pad_fn,
+ args=(np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]),),
+ expected=np.array(
+ [[7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 0, 1, 7, 7],
+ [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
+ dtype=dtype))
+
+ def testReduce(self):
+ for dtype in set(self.numeric_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ @function.Defun(dtype, dtype)
+ def sum_reducer(x, y):
+ return x + y
+
+ def sum_reduction(dims):
+
+ def fn(x):
+ return xla.reduce(
+ x, init_value=0, dimensions_to_reduce=dims, reducer=sum_reducer)
+
+ return fn
+
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[0]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([12, 15, 18, 21], dtype=dtype))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[1]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([6, 22, 38], dtype=dtype))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[0, 1]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=dtype(66))
+
+ @function.Defun(dtype, dtype)
+ def mul_reducer(x, y):
+ return x * y
+
+ def mul_reduction(dims):
+
+ def fn(x):
+ return xla.reduce(
+ x, init_value=1, dimensions_to_reduce=dims, reducer=mul_reducer)
+
+ return fn
+
+ self._assertOpOutputMatchesExpected(
+ mul_reduction(dims=[0]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([0, 45, 120, 231], dtype=dtype))
+
+ def testSelectAndScatter(self):
+ for dtype in set(self.numeric_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ @function.Defun(dtype, dtype)
+ def add_scatter(x, y):
+ return x + y
+
+ @function.Defun(dtype, dtype)
+ def ge_select(x, y):
+ return x >= y
+
+ def test_fn(operand, source):
+ return xla.select_and_scatter(
+ operand,
+ window_dimensions=[2, 3, 1, 1],
+ window_strides=[2, 2, 1, 1],
+ padding=[[0, 0]] * 4,
+ source=source,
+ init_value=0,
+ select=ge_select,
+ scatter=add_scatter)
+
+ self._assertOpOutputMatchesExpected(
+ test_fn,
+ args=(np.array(
+ [[7, 2, 5, 3, 8], [3, 8, 9, 3, 4], [1, 5, 7, 5, 6],
+ [0, 6, 2, 10, 2]],
+ dtype=dtype).reshape((4, 5, 1, 1)),
+ np.array([[2, 6], [3, 1]], dtype=dtype).reshape((2, 2, 1, 1))),
+ expected=np.array(
+ [[0, 0, 0, 0, 0], [0, 0, 8, 0, 0], [0, 0, 3, 0, 0],
+ [0, 0, 0, 1, 0]],
+ dtype=dtype).reshape((4, 5, 1, 1)))
+
+ def testTranspose(self):
+ for dtype in self.numeric_types:
+ v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
+ self._assertOpOutputMatchesExpected(
+ lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 85fd0c9217d8e56be564f915e5c950d4fadc4e59..92e577bb7b930f5b9139e361cafb8628daede455 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -39,6 +39,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -88,6 +89,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -221,13 +223,11 @@ cc_library(
srcs = [
"literal_util.cc",
"shape_util.cc",
- "str_util.cc",
"type_util.cc",
],
hdrs = [
"literal_util.h",
"shape_util.h",
- "str_util.h",
"type_util.h",
],
visibility = [":friends"],
@@ -256,6 +256,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -307,6 +308,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -374,19 +376,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
- ],
-)
-
-tf_cc_test(
- name = "str_util_test",
- srcs = [
- "str_util_test.cc",
- ],
- deps = [
- ":common",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -459,6 +449,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -482,6 +473,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -609,3 +601,30 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+cc_library(
+ name = "resource_operation_table",
+ srcs = ["resource_operation_table.cc"],
+ hdrs = ["resource_operation_table.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/algorithm:container",
+ ],
+)
+
+tf_cc_test(
+ name = "resource_operation_table_test",
+ srcs = ["resource_operation_table_test.cc"],
+ deps = [
+ ":resource_operation_table",
+ ":xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index de1008803d69fefa415c7bdbe6c27a62e625b417..e8673d77903bd5a1a85412e9dfa86437f73d56bc 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -23,11 +23,11 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
namespace tensorflow {
-
// Backwards dataflow analysis that finds arguments to a graph that must be
// compile-time constants.
Status BackwardsConstAnalysis(const Graph& g,
- std::vector* compile_time_const_args) {
+ std::vector* compile_time_const_args,
+ std::vector* compile_time_const_nodes) {
// Operators that don't look at the data of their inputs, just the shapes.
const std::unordered_set metadata_ops = {
"Rank",
@@ -36,9 +36,16 @@ Status BackwardsConstAnalysis(const Graph& g,
"Size",
};
+ std::vector compile_time_const_nodes_impl;
+ if (compile_time_const_nodes) {
+ CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
+ } else {
+ compile_time_const_nodes_impl.resize(g.num_node_ids());
+ compile_time_const_nodes = &compile_time_const_nodes_impl;
+ }
+
Status status;
- std::unordered_set must_be_const;
- auto visit = [&status, &metadata_ops, &must_be_const,
+ auto visit = [&status, &metadata_ops, compile_time_const_nodes,
compile_time_const_args](Node* node) {
if (!status.ok()) return;
@@ -47,17 +54,19 @@ Status BackwardsConstAnalysis(const Graph& g,
// If this node must be const, and it isn't a metadata op, then all of its
// parents must be const.
- if (must_be_const.find(node) != must_be_const.end()) {
+ if ((*compile_time_const_nodes)[node->id()]) {
if (node->type_string() == "_Arg") {
int index;
status = GetNodeAttr(node->attrs(), "index", &index);
if (!status.ok()) return;
- compile_time_const_args->at(index) = true;
+ if (compile_time_const_args) {
+ (*compile_time_const_args)[index] = true;
+ }
return;
}
for (const Edge* pred : node->in_edges()) {
if (!pred->IsControlEdge()) {
- must_be_const.insert(pred->src());
+ (*compile_time_const_nodes)[pred->src()->id()] = true;
}
}
return;
@@ -80,7 +89,7 @@ Status BackwardsConstAnalysis(const Graph& g,
for (Edge const* edge : node->in_edges()) {
if (edge->dst_input() >= name_range->second.first &&
edge->dst_input() < name_range->second.second) {
- must_be_const.insert(edge->src());
+ (*compile_time_const_nodes)[edge->src()->id()] = true;
}
}
}
diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h
index 634b97d7e3760c0344c948a56353ade243284aa6..af57e5a4033248e3fd32dabeda252c4ca0a44050 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.h
+++ b/tensorflow/compiler/tf2xla/const_analysis.h
@@ -23,10 +23,18 @@ limitations under the License.
namespace tensorflow {
-// Backwards dataflow analysis that finds arguments (_Arg nodes) to a graph that
-// must be compile-time constants.
+// Backwards dataflow analysis that finds nodes in a graph that must be
+// compile-time constants for us to be able to lower the graph to XLA.
+//
+// The indices of the arguments to `graph` that must be constant are returned in
+// `compile_time_const_arg_indices`, if `compile_time_const_arg_indices` is not
+// null.
+//
+// The ids of the nodes in `graph` that must be constant are returned in
+// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
Status BackwardsConstAnalysis(const Graph& graph,
- std::vector* compile_time_const_args);
+ std::vector* compile_time_const_arg_indices,
+ std::vector* compile_time_const_nodes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc
index 992b12c06db5efc0ae54284d0ea77017c1c79aca..56065be894697bc72ecc0089c665c19aafee7bf8 100644
--- a/tensorflow/compiler/tf2xla/const_analysis_test.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -38,17 +39,23 @@ TEST(ConstAnalysisTest, Basics) {
auto c = ops::Reshape(root, arg2, b);
auto d = ops::Mul(root, c, ops::Sum(root, arg3, arg3));
- Graph graph(OpRegistry::Global());
- TF_ASSERT_OK(root.ToGraph(&graph));
+ FixupSourceAndSinkEdges(root.graph());
std::vector const_args(4, false);
- TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args));
+ std::vector const_nodes(root.graph()->num_node_ids(), false);
+ TF_ASSERT_OK(
+ BackwardsConstAnalysis(*root.graph(), &const_args, &const_nodes));
// Arg 0 doesn't need to be constant since the graph only uses its shape.
// Arg 1 must be constant because it flows to the shape argument of a Reshape.
// Arg 2 is used only as the value input to a Reshape and need not be const.
// Arg 3 is used as the reduction-indices argument to Sum and must be const.
EXPECT_EQ(const_args, std::vector({false, true, false, true}));
+
+ EXPECT_FALSE(const_nodes[arg0.node()->id()]);
+ EXPECT_TRUE(const_nodes[arg1.node()->id()]);
+ EXPECT_FALSE(const_nodes[arg2.node()->id()]);
+ EXPECT_TRUE(const_nodes[arg3.node()->id()]);
}
// Regression test for a case where the backward const analysis did
@@ -73,7 +80,8 @@ TEST(ConstAnalysisTest, TopologicalOrder) {
TF_ASSERT_OK(root.ToGraph(&graph));
std::vector const_args(3, false);
- TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args));
+ TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
+ /*compile_time_const_nodes=*/nullptr));
EXPECT_EQ(const_args, std::vector({true, true, false}));
}
@@ -93,7 +101,8 @@ TEST(ConstAnalysisTest, DontFollowControlDependencies) {
TF_ASSERT_OK(root.ToGraph(&graph));
std::vector const_args(2, false);
- TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args));
+ TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
+ /*compile_time_const_nodes=*/nullptr));
EXPECT_EQ(const_args, std::vector({false, true}));
}
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index f14cfca4eaf654abc1d37c8abd34fbdae2bd24d7..b5667ca0d3ba35bea9da2d702b5b49fb38fe6f02 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include
#include "absl/memory/memory.h"
+#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
@@ -52,11 +53,10 @@ string DebugString(CondStateMap::CondId cond_state) {
if (cond_state == nullptr || cond_state->empty()) return "[]";
return strings::StrCat(
"[",
- tensorflow::str_util::Join(
- *cond_state, ", ",
- [](string* output, const CondStateMap::CondNode& node) {
- strings::StrAppend(output, node.ToString());
- }),
+ absl::StrJoin(*cond_state, ", ",
+ [](string* output, const CondStateMap::CondNode& node) {
+ strings::StrAppend(output, node.ToString());
+ }),
"]");
}
@@ -169,10 +169,10 @@ using CondArgNodes = std::vector;
string DebugString(const CondArgNodes& nodes) {
return strings::StrCat(
"[",
- tensorflow::str_util::Join(nodes, ", ",
- [](string* output, const CondArgNode& node) {
- strings::StrAppend(output, node.ToString());
- }),
+ absl::StrJoin(nodes, ", ",
+ [](string* output, const CondArgNode& node) {
+ strings::StrAppend(output, node.ToString());
+ }),
"]");
}
@@ -387,8 +387,9 @@ Status Conditional::BuildArgumentNodes() {
}
if (!has_input) {
return errors::Internal(
- "Failed to functionalize control flow with merge '", m->name(),
- "' that doesn't have input on ", Branch_Name(branch), " branch.");
+ "Failed to functionalize control flow with merge ",
+ FormatNodeForError(*m), " that doesn't have input on ",
+ Branch_Name(branch), " branch.");
}
}
}
@@ -469,8 +470,8 @@ Status Conditional::ExtractBodies(Graph* graph) {
// but revisit to improve the testing to enable making this an
// error.
LOG(WARNING) << errors::InvalidArgument(
- "Graph contains node ", src->name(), " that feeds into node ",
- dst->name(),
+ "Graph contains node ", FormatNodeForError(*src),
+ " that feeds into node ", FormatNodeForError(*dst),
" but these nodes are in different control contexts (",
DebugString(src_id), " vs ", DebugString(dst_id),
" (detected during out edge testing)");
@@ -512,8 +513,8 @@ Status Conditional::ExtractBodies(Graph* graph) {
node_map.at(src->id()) = output->CopyNode(src);
} else {
return errors::InvalidArgument(
- "Graph contains node ", src->name(), " that feeds into node ",
- dst->name(),
+ "Graph contains node ", FormatNodeForError(*src),
+ " that feeds into node ", FormatNodeForError(*dst),
" but these nodes are in different control contexts (",
DebugString(src_id), " vs ", DebugString(dst_id),
" (detected during in edge testing)");
@@ -675,7 +676,8 @@ Status Conditional::AddOutputEdges(Graph* graph) {
int dst_input = edge->dst_input();
if (edge->src_output() > 0) {
return errors::Unimplemented("Output of index (", edge->src_output(),
- ") of merge node ", node->name());
+ ") of merge node ",
+ FormatNodeForError(*node));
}
bool control_edge = edge->IsControlEdge();
@@ -1060,7 +1062,8 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
CondStateMap::CondId prop = StateAlongEdge(e);
auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst));
- TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name());
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
+ FormatNodeForError(*dst));
cond_state_map_.ResetId(dst, id_or.ValueOrDie());
}
@@ -1090,7 +1093,8 @@ Status FunctionalizeCond::DetermineCondState(Node* dst) {
// Joining the state between the current and propagated state.
CondStateMap::CondId prop = StateAlongEdge(e);
auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst));
- TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name());
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
+ FormatNodeForError(*dst));
cond_state_map_.ResetId(dst, id_or.ValueOrDie());
}
}
@@ -1117,7 +1121,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
}
if (non_dead_edge == nullptr) {
- return errors::InvalidArgument("Merge node ", node->name(),
+ return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
" has no non-dead inputs.");
}
cond_state_map_.MarkDead(node);
@@ -1169,7 +1173,8 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
if (IsMerge(dst_node)) {
auto id_or =
JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node));
- TF_RETURN_IF_ERROR(id_or.status());
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
+ FormatNodeForError(*dst_node));
cond_state_map_.ResetId(dst_node, id_or.ValueOrDie());
} else {
auto id_or =
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
index a0544b69e9ea3a1bd16dcd08bc4b4638a8fc31fb..61940e3586c59ffc660eaac8f8d035fbbbdfeffd 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/graph/graph.h"
@@ -43,11 +44,11 @@ xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index);
template
string NodesToString(const T& nodes) {
return strings::StrCat("{",
- str_util::Join(nodes, ",",
- [](string* output, const Node* node) {
- strings::StrAppend(output,
- node->name());
- }),
+ absl::StrJoin(nodes, ",",
+ [](string* output, const Node* node) {
+ strings::StrAppend(output,
+ node->name());
+ }),
"}");
}
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index e4fdf0a6186eb69a2e3413838c91616b992ef2d6..1ed1fb3b021b27be00086b2e71cc9309e3d76049 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -57,7 +57,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
std::vector compile_time_constant_flags(expressions.size());
TF_RETURN_IF_ERROR(
- BackwardsConstAnalysis(*graph, &compile_time_constant_flags));
+ BackwardsConstAnalysis(*graph, &compile_time_constant_flags,
+ /*compile_time_const_nodes=*/nullptr));
args->resize(expressions.size());
for (int i = 0; i < args->size(); ++i) {
@@ -145,6 +146,7 @@ Status GraphCompiler::Compile() {
}
OpKernelContext op_context(¶ms, n->num_outputs());
+ VLOG(3) << "Translating " << params.op_kernel->name();
if (IsFunctional(n)) {
TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
} else {
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index b1366e9e31e28406c5bf1a808b9c5670558ed9c7..c1438f893f6d3c46dd7f6c39b6aa3367a79789f0 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -22,6 +22,7 @@ tf_kernel_library(
"bcast_ops.cc",
"bias_ops.cc",
"binary_ops.cc",
+ "broadcast_to_op.cc",
"bucketize_op.cc",
"cast_op.cc",
"categorical_op.cc",
@@ -100,6 +101,12 @@ tf_kernel_library(
"unary_ops.cc",
"unpack_op.cc",
"variable_ops.cc",
+ "xla_broadcast_helper_op.cc",
+ "xla_conv_op.cc",
+ "xla_dot_op.cc",
+ "xla_pad_op.cc",
+ "xla_reduce_op.cc",
+ "xla_select_and_scatter_op.cc",
],
hdrs = [
"index_ops.h",
@@ -108,6 +115,8 @@ tf_kernel_library(
deps = [
":if_op",
":while_op",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index ba3b1c9dab79a387c48e8e25e4804917f328f8a0..2e383b1473590403823863f89264e5381d8e8806 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
// XLA-specific Ops for broadcasting used in gradient
// code.
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -51,8 +52,8 @@ class BCastArgsOp : public XlaOpKernel {
BCast bcast(shapes[0], shapes[1]);
OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument(
- "Incompatible shapes: [", str_util::Join(shapes[0], ","),
- "] vs. [", str_util::Join(shapes[1], ","), "]"));
+ "Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
+ "] vs. [", absl::StrJoin(shapes[1], ","), "]"));
const int64 len = bcast.output_shape().size();
Tensor output(DT_INT32, TensorShape({len}));
@@ -105,8 +106,8 @@ class BCastGradArgsOp : public XlaOpKernel {
BCast bcast(shapes[0], shapes[1]);
OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument(
- "Incompatible shapes: [", str_util::Join(shapes[0], ","),
- "] vs. [", str_util::Join(shapes[1], ","), "]"));
+ "Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
+ "] vs. [", absl::StrJoin(shapes[1], ","), "]"));
Output(ctx, 0, bcast.grad_x_reduce_idx());
Output(ctx, 1, bcast.grad_y_reduce_idx());
}
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4bd7c74dca2a7cbb51f2a329ac575d635f314516
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/bcast.h"
+
+namespace tensorflow {
+namespace {
+
+class BroadcastToOp : public XlaOpKernel {
+ public:
+ explicit BroadcastToOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
+ TensorShape output_shape;
+ OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
+
+ OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(),
+ errors::InvalidArgument(
+ "Input rank (", input_shape.dims(),
+ ") must be less than or equal to the output rank (",
+ output_shape.dims(), ")"));
+
+ auto input_dims = input_shape.dim_sizes();
+ auto output_dims = output_shape.dim_sizes();
+
+ // Broadcasting is done right-to-left on right-aligned dimensions; reverse
+ // the two vectors so elements to be broadcast are aligned.
+ absl::c_reverse(input_dims);
+ absl::c_reverse(output_dims);
+
+ std::vector broadcast_dims;
+ std::vector broadcast_shape;
+ for (int i = 0; i < output_shape.dims(); ++i) {
+ if (i < input_shape.dims()) {
+ OP_REQUIRES(
+ context,
+ (output_dims[i] == 0 && input_dims[i] == 0) ||
+ (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0),
+ errors::InvalidArgument("invalid shape to broadcast from ",
+ input_shape.DebugString(), " to ",
+ output_shape.DebugString()));
+
+ broadcast_dims.push_back(broadcast_shape.size());
+ if (output_dims[i] == input_dims[i] || input_dims[i] == 1) {
+ broadcast_shape.push_back(output_dims[i]);
+ }
+ if (output_dims[i] != input_dims[i]) {
+ // Add dimensions [I, O/I], which we will later flatten to just
+ // [O]. We must do this in two phases since XLA broadcasting does not
+ // support tiling.
+ broadcast_shape.push_back(input_dims[i]);
+ broadcast_shape.push_back(output_dims[i] / input_dims[i]);
+ }
+ } else {
+ broadcast_shape.push_back(output_dims[i]);
+ }
+ }
+ absl::c_reverse(broadcast_dims);
+ int broadcast_shape_size = broadcast_shape.size();
+ for (int64& broadcast_dim : broadcast_dims) {
+ broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
+ }
+ absl::c_reverse(broadcast_shape);
+ xla::XlaOp output = xla::Reshape(
+ xla::BroadcastInDim(context->Input(0),
+ xla::ShapeUtil::MakeShape(
+ context->input_xla_type(0), broadcast_shape),
+ broadcast_dims),
+ output_shape.dim_sizes());
+ context->SetOutput(0, output);
+ }
+};
+
+REGISTER_XLA_OP(Name("BroadcastTo").CompileTimeConstInput("shape"),
+ BroadcastToOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index ed44ad218b6dc073583ec339da082b6881ad672d..70c3eaf66bbd6470734d1e5fc9978510022ac7bc 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -178,7 +178,7 @@ class MatrixDiagOp : public XlaOpKernel {
int last_dim = dims.size() - 1;
int64 last_dim_size = input_shape.dim_size(last_dim);
tensorflow::gtl::ArraySlice other_dims(dims);
- other_dims.pop_back();
+ other_dims.remove_suffix(1);
xla::XlaOp input = ctx->Input(0);
xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims,
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index 8d75624e74028ea083c3facc4f9578ec14c50e6d..8e071bf0b7ae638888818ea8cd5d63b5d543342e 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -32,13 +32,13 @@ namespace {
//
// 1. S := (N - 1) / gcd(N-1, R-1)
// 2. k := (R - 1) / gcd(N-1, R-1)
-// 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1)
+// 3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1)
//
// For example, to Scale from 7x7 -> 15x15:
//
// 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3
// 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7
-// 3. Convolution(7x7, stride=3, lhs_dilation=3, padding=2)
+// 3. Convolution(15x15, stride=3, lhs_dilation=7, padding=2)
//
//
// The 7x7 -> 15x15 case is much too large to write out in full as an
@@ -65,6 +65,8 @@ namespace {
// 1/9 * 3 6 9 6 3
// 2 4 6 4 2
// 1 2 3 2 1
+// Note that the convolution kernel matrix is separable and thus we can instead
+// use 2 consecutive 1D kernel of the dimension 2k-1, along each axis.
// Computes the size of the convolutional kernel and stride to use when resizing
// from in_size to out_size.
@@ -76,7 +78,8 @@ struct ResizeConvolutionDims {
std::vector stride;
};
ResizeConvolutionDims ComputeResizeConvolutionParameters(
- gtl::ArraySlice in_size, gtl::ArraySlice out_size) {
+ gtl::ArraySlice in_size, gtl::ArraySlice out_size,
+ bool align_corners) {
CHECK_EQ(in_size.size(), out_size.size());
int num_spatial_dims = in_size.size();
ResizeConvolutionDims dims;
@@ -92,15 +95,32 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters(
// entry before resizing.
dims.stride[i] = dims.kernel_size[i] = 1;
} else {
- int64 gcd = MathUtil::GCD(static_cast(in_size[i] - 1),
- static_cast(out_size[i] - 1));
- dims.stride[i] = (in_size[i] - 1) / gcd;
- dims.kernel_size[i] = (out_size[i] - 1) / gcd;
+ // The scaling factor changes depending on the alignment of corners.
+ const int64 in_size_factor = align_corners ? in_size[i] - 1 : in_size[i];
+ const int64 out_size_factor =
+ align_corners ? out_size[i] - 1 : out_size[i];
+
+ int64 gcd = MathUtil::GCD(static_cast(in_size_factor),
+ static_cast(out_size_factor));
+ dims.stride[i] = in_size_factor / gcd;
+ dims.kernel_size[i] = out_size_factor / gcd;
}
}
return dims;
}
+// The upper padding of the input needed by ConvGeneralDilated calls is
+// determined by solving two related relationships (assuming rhs_dilation == 0):
+// 1. dilated_input_dim = lower_padding + upper_padding
+// + lhs_dilation * (in_size - 1) + 1
+// 2. dilated_input_dim = (2 * dims.kernel-size - 1)
+// + dims.stride * (out_size - 1)
+int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size,
+ int64 stride) {
+ return (2 * kernel_size - 1) + (out_size - 1) * stride - (kernel_size - 1) -
+ 1 - (kernel_size * (in_size - 1));
+}
+
// Form a 2D convolution kernel like:
// 1 2 3 2 1
// 2 4 6 4 2
@@ -171,7 +191,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
const int num_spatial_dims,
std::vector in_size,
std::vector out_size,
- const int64 channels) {
+ const int64 channels,
+ const bool align_corners) {
// Picture for a 1x3 to 1x4 resize:
// stride = 2, kernel size = 3
// Input:
@@ -196,27 +217,82 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
ResizeConvolutionDims dims =
- ComputeResizeConvolutionParameters(in_size, out_size);
+ ComputeResizeConvolutionParameters(in_size, out_size, align_corners);
xla::XlaOp output;
- // Split convolutions into independent dimensions if they wmuld be a very
+
+ // Concatenation and padding below currently assumes num_spatial_dims is 2 to
+ // prevent needless code complexity.
+ CHECK_EQ(num_spatial_dims, 2)
+ << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently.";
+ std::vector upper_padding(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ upper_padding[i] = dims.kernel_size[i] - 1;
+ }
+ xla::XlaOp input_data = input;
+
+ if (!align_corners) {
+ // When Tensorflow does not align_corners, the resize indexing can access
+ // beyond the upper bound and is instead clamped to prevent out of bounds
+ // reads. This is conceptually the same as extending the edges of the input.
+ // We emulate this by copying the last row/column of the input.
+ // Calculate what padding would be needed then determine how far to extend
+ // the border before lhs dilation.
+ std::vector num_extended(num_spatial_dims);
+ upper_padding[0] = CalculateUpperPadding(
+ in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]);
+ upper_padding[1] = CalculateUpperPadding(
+ in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]);
+ num_extended[0] = upper_padding[0] / (dims.kernel_size[0]);
+ num_extended[1] = upper_padding[1] / (dims.kernel_size[1]);
+
+ if (num_extended[0] > 0) {
+ auto slice =
+ xla::Slice(input_data, {0, in_size[0] - 1, 0, 0},
+ {1, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
+ for (int i = 0; i < num_extended[0]; i++) {
+ input_data = xla::ConcatInDim(builder, {input_data, slice}, 1);
+ }
+ }
+
+ if (num_extended[1] > 0) {
+ auto slice =
+ xla::Slice(input_data, {0, 0, in_size[1] - 1, 0},
+ {1, in_size[0] + num_extended[0], in_size[1], channels},
+ {1, 1, 1, 1});
+ for (int i = 0; i < num_extended[1]; i++) {
+ input_data = xla::ConcatInDim(builder, {input_data, slice}, 2);
+ }
+ }
+
+ // Setting in_size to (in_size + num_extended) due to the above Slice and
+ // ConcatInDim. Recalculate needed padding after the above Slice/Concat.
+ upper_padding[0] =
+ CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0],
+ dims.kernel_size[0], dims.stride[0]);
+ upper_padding[1] =
+ CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1],
+ dims.kernel_size[1], dims.stride[1]);
+ }
+
+ // Split convolutions into independent dimensions if they would be a very
// large kernel.
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
- output = xla::ConvGeneralDilated(
- input, kernel, dims.stride,
- /*padding=*/
- {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
- {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
- /*lhs_dilation=*/dims.kernel_size,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ output =
+ xla::ConvGeneralDilated(input_data, kernel, dims.stride,
+ /*padding=*/
+ {{dims.kernel_size[0] - 1, upper_padding[0]},
+ {dims.kernel_size[1] - 1, upper_padding[1]}},
+ /*lhs_dilation=*/dims.kernel_size,
+ /*rhs_dilation=*/{1, 1}, dimension_numbers);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
output = xla::ConvGeneralDilated(
- input, kernel0, {dims.stride[0], 1},
+ input_data, kernel0, {dims.stride[0], 1},
/*padding=*/
- {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
+ {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
/*lhs_dilation=*/{dims.kernel_size[0], 1},
/*rhs_dilation=*/{1, 1}, dimension_numbers);
xla::XlaOp kernel1 =
@@ -224,7 +300,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
output = xla::ConvGeneralDilated(
output, kernel1, {1, dims.stride[1]},
/*padding=*/
- {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
+ {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/{1, dims.kernel_size[1]},
/*rhs_dilation=*/{1, 1}, dimension_numbers);
}
@@ -245,9 +321,10 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
const int num_spatial_dims,
std::vector in_size,
std::vector grad_size,
- const int64 channels) {
+ const int64 channels,
+ const bool align_corners) {
ResizeConvolutionDims dims =
- ComputeResizeConvolutionParameters(in_size, grad_size);
+ ComputeResizeConvolutionParameters(in_size, grad_size, align_corners);
// To form the backward convolution, we keep the kernel unchanged (it is
// already symmetric) and swap the roles of strides and LHS dilation.
@@ -341,10 +418,6 @@ class ResizeBilinearOp : public XlaOpKernel {
public:
explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
- OP_REQUIRES(
- ctx, align_corners_ == true,
- errors::Unimplemented(
- "ResizeBilinear with align_corners=False is not yet implemented"));
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -377,20 +450,19 @@ class ResizeBilinearOp : public XlaOpKernel {
// If in_size[i] > 1 and out_size[i] == 1, slice out the first input in
// dimension i.
- std::vector slice_size = in_size;
bool slice_input = false;
for (int i = 0; i < num_spatial_dims; ++i) {
if (in_size[i] > 1 && out_size[i] == 1) {
// If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
// entry before resizing.
slice_input = true;
- slice_size[i] = 1;
+ in_size[i] = 1;
}
}
if (slice_input) {
- input = xla::Slice(input, {0, 0, 0, 0},
- {batch, slice_size[0], slice_size[1], channels},
- {1, 1, 1, 1});
+ input =
+ xla::Slice(input, {0, 0, 0, 0},
+ {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
}
// Output is always type float.
@@ -406,6 +478,9 @@ class ResizeBilinearOp : public XlaOpKernel {
// operations along different dimensions.
// Given sufficient numerical stability and a cxd is same as resizing axb -> exf -> cxd.
+ // This does not work in the case of align_corners_=false because of special
+ // padding requirements that cause multiple resizes to be very different
+ // from a single resize.
//
// This makes the convolutions kernels smaller and the operation faster.
xla::XlaOp output = input;
@@ -415,21 +490,24 @@ class ResizeBilinearOp : public XlaOpKernel {
(static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2),
(static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)};
if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) &&
- k[0] > 1 && k[1] > 1) {
+ k[0] > 1 && k[1] > 1 && align_corners_) {
std::vector next_out_size = {(in_size[0] - 1) * 2 + 1,
(in_size[1] - 1) * 2 + 1};
- output = ResizeUsingDilationAndConvolution(
- b, input, num_spatial_dims, in_size, next_out_size, channels);
+ output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims,
+ in_size, next_out_size,
+ channels, align_corners_);
input = output;
in_size = next_out_size;
} else {
- output = ResizeUsingDilationAndConvolution(
- b, input, num_spatial_dims, in_size, out_size, channels);
+ output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims,
+ in_size, out_size,
+ channels, align_corners_);
in_size = out_size;
}
} else {
output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims,
- in_size, out_size, channels);
+ in_size, out_size, channels,
+ align_corners_);
in_size = out_size;
}
}
@@ -509,17 +587,20 @@ class ResizeBilinearGradOp : public XlaOpKernel {
std::vector next_grad_size = {(in_size[0] - 1) * 2 + 1,
(in_size[1] - 1) * 2 + 1};
output = ResizeUsingDilationAndConvolutionGradOp(
- b, grad, num_spatial_dims, in_size, next_grad_size, channels);
+ b, grad, num_spatial_dims, in_size, next_grad_size, channels,
+ align_corners_);
grad = output;
in_size = next_grad_size;
} else {
output = ResizeUsingDilationAndConvolutionGradOp(
- b, grad, num_spatial_dims, in_size, grad_size, channels);
+ b, grad, num_spatial_dims, in_size, grad_size, channels,
+ align_corners_);
in_size = grad_size;
}
} else {
output = ResizeUsingDilationAndConvolutionGradOp(
- b, grad, num_spatial_dims, in_size, grad_size, channels);
+ b, grad, num_spatial_dims, in_size, grad_size, channels,
+ align_corners_);
in_size = grad_size;
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
index eedfc3c9140d7b1ccc1944611de98c1d49fbdaf2..2a42eeaf76ab3aa88ff3a93ef7eb7ab217964bb6 100644
--- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
@@ -29,7 +29,14 @@ class MirrorPadOp : public XlaOpKernel {
xla::StatusOr DoMirrorPad(const xla::XlaOp& t,
const xla::Shape& original_shape,
const xla::LiteralSlice& pad_literal,
+ const MirrorPadMode mode,
xla::XlaBuilder* b) {
+ // The difference in the semantics of REFLECT and SYMMETRIC is that REFLECT
+ // will not mirror the border values while symmetric does.
+ // e.g. input is [1, 2, 3] and paddings is [0, 2], then the output is:
+ // - [1, 2, 3, 2, 1] in reflect mode
+ // - [1, 2, 3, 3, 2] in symmetric mode.
+ int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0;
xla::XlaOp accum = t;
for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
--dimno) {
@@ -39,9 +46,19 @@ class MirrorPadOp : public XlaOpKernel {
TF_ASSIGN_OR_RETURN(int64 rhs_padding,
pad_literal.GetIntegralAsS64({dimno, 1}));
int64 dim_size = original_shape.dimensions(dimno);
- auto lhs_pad = xla::SliceInDim(t_rev, dim_size - 1 - lhs_padding,
- dim_size - 1, 1, dimno);
- auto rhs_pad = xla::SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno);
+
+ // Padding amounts on each side must be no more than the size of the
+ // original shape.
+ TF_RET_CHECK(lhs_padding >= 0 &&
+ lhs_padding <= dim_size - excluded_edges);
+ TF_RET_CHECK(rhs_padding >= 0 &&
+ rhs_padding <= dim_size - excluded_edges);
+
+ auto lhs_pad =
+ xla::SliceInDim(t_rev, dim_size - excluded_edges - lhs_padding,
+ dim_size - excluded_edges, 1, dimno);
+ auto rhs_pad = xla::SliceInDim(t_rev, excluded_edges,
+ excluded_edges + rhs_padding, 1, dimno);
accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno);
}
return accum;
@@ -53,9 +70,10 @@ class MirrorPadOp : public XlaOpKernel {
MirrorPadMode mode;
OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode));
- OP_REQUIRES(ctx, mode == MirrorPadMode::REFLECT,
- xla::Unimplemented(
- "Only REFLECT MirrorPad mode is currently supported"));
+ OP_REQUIRES(
+ ctx, mode == MirrorPadMode::REFLECT || mode == MirrorPadMode::SYMMETRIC,
+ xla::Unimplemented("Unsupported MirrorPad mode. Only SYMMETRIC and "
+ "REFLECT modes are currently supported"));
const int dims = input_shape.dims();
OP_REQUIRES(
@@ -83,7 +101,7 @@ class MirrorPadOp : public XlaOpKernel {
xla::StatusOr in0_shape = b->GetShape(in0);
OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status());
xla::StatusOr accum_status =
- DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, b);
+ DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, mode, b);
OP_REQUIRES_OK(ctx, accum_status.status());
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index d4d180aff806f12875f0e43f111ee090f6607ef6..f6f158a73be42ea2602811ad64a2a2c655dab088 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -199,59 +199,6 @@ class MaxPool3DOp : public MaxPoolOp {
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
-// Divide each element of an image by the count of elements that contributed to
-// that element during pooling.
-static xla::XlaOp AvgPoolDivideByCount(
- XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype,
- const TensorShape& input_shape, xla::Padding padding,
- const std::vector& ksize, const std::vector& stride,
- int num_spatial_dims, TensorFormat data_format) {
- if (padding == xla::Padding::kValid) {
- // In VALID padding, all windows have the same number of elements
- // contributing to each average. Divide by the window size everywhere to
- // get the average.
- int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1,
- [](int64 a, int64 b) { return a * b; });
-
- auto divisor =
- XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size);
- return xla::Div(output, divisor);
- } else {
- // For SAME padding, the padding shouldn't be included in the
- // counts. We use another ReduceWindow to find the right counts.
-
- // TODO(phawkins): use a less brute-force way to compute this. Only
- // the boundary regions will have interesting values here.
-
- std::vector input_dim_sizes(num_spatial_dims);
- std::vector window_dims(num_spatial_dims);
- std::vector window_ksize(num_spatial_dims);
- std::vector window_stride(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i);
- input_dim_sizes[i] = input_shape.dim_size(dim);
- window_dims[i] = dim;
- window_ksize[i] = ksize[dim];
- window_stride[i] = stride[dim];
- }
-
- // Build a matrix of all 1s, with the same width/height as the input.
- const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto ones = xla::Broadcast(
- XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes);
-
- // Perform a ReduceWindow with the same window size, strides, and padding
- // to count the number of contributions to each result element.
- auto reduce = xla::ReduceWindow(
- ones, XlaHelpers::Zero(ctx->builder(), accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride,
- xla::Padding::kSame);
- auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype);
-
- return xla::Div(output, counts, window_dims);
- }
-}
-
class AvgPoolOp : public PoolingOp {
public:
AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
@@ -463,78 +410,31 @@ class AvgPoolGradOp : public XlaOpKernel {
errors::InvalidArgument("out_backprop must be ", num_dims(),
"-dimensional"));
- int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- int64 depth = out_backprop_shape.dim_size(depth_dim);
-
- // We can think of average-pooling as:
- // * a convolution with a kernel consisting entirely of 1s, where the
- // input feature and output feature are equal, and 0s everywhere else.
- // * followed by dividing by the counts.
- //
- // This then gives us an algorithm to build the gradient:
- // * divide out_backprop by the counts, followed by
- // * Conv2DBackpropInput specialized for that kernel, which simplifies to
- // a Pad and a ReduceWindow.
- //
- // For an explanation of backpropagation for convolution, see the comments
- // in third_party/tensorflow/core/kernels/conv_grad_ops.h
-
- // TF filter shape is [ H, W, ..., inC, outC ]
- std::vector filter_dims(num_dims());
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- filter_dims[i] = ksize_[dim];
- }
- filter_dims[num_dims() - 2] = depth;
- filter_dims[num_dims() - 1] = depth;
- TensorShape filter_shape(filter_dims);
-
- // Reuse the logic from Conv2DBackpropInput to compute padding.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(
- ctx, ConvBackpropComputeDimensions(
- type_string(), /*num_spatial_dims=*/num_spatial_dims_,
- gradients_shape, filter_shape, out_backprop_shape, stride_,
- padding_, data_format_, &dims));
-
- // The input gradients are computed by a convolution of the output gradients
- // and the filter, with some appropriate padding. See the comment at the top
- // of conv_grad_ops.h for details.
- xla::XlaBuilder* const b = ctx->builder();
auto out_backprop = ctx->Input(1);
- auto dtype = input_type(1);
+ std::vector stride_int64s(stride_.begin(), stride_.end());
xla::Padding xla_padding =
(padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
-
- // Divide the out_backprop values by the counts for each spatial position.
- std::vector stride_int64s(stride_.begin(), stride_.end());
- auto out_backprop_div = AvgPoolDivideByCount(
- ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_,
- stride_int64s, num_spatial_dims_, data_format_);
-
- // Pad the gradients in the spatial dimensions. We use the same padding
- // as Conv2DBackpropInput.
- xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims());
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- auto* padding = padding_config.mutable_dimensions(dim);
- padding->set_edge_padding_low(dims.spatial_dims[i].pad_before);
- padding->set_edge_padding_high(dims.spatial_dims[i].pad_after);
- padding->set_interior_padding(dims.spatial_dims[i].stride - 1);
- }
-
- auto zero = XlaHelpers::Zero(b, dtype);
- auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config);
-
- // in_backprop = padded_gradients ones
- std::vector ones(num_dims(), 1LL);
- auto accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto in_backprop = xla::ReduceWindow(
- XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type),
- XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), ksize_,
- /* window_strides=*/ones, xla::Padding::kValid);
- ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype));
+ xla::PrimitiveType xla_reduction_type;
+ auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
+ OP_REQUIRES_OK(
+ ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
+ auto converted_out_backprop =
+ xla::ConvertElementType(out_backprop, xla_reduction_type);
+ auto xla_data_format =
+ XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
+ auto padding_values =
+ MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
+ xla_padding, xla_data_format);
+ auto in_backprop =
+ xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
+ ksize_, stride_int64s, padding_values, xla_data_format,
+ /*counts_include_padding=*/padding_ == VALID);
+ // Convert the pooling result back to the input type before returning it.
+ xla::PrimitiveType xla_out_backprop_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
+ &xla_out_backprop_type));
+ ctx->SetOutput(0,
+ xla::ConvertElementType(in_backprop, xla_out_backprop_type));
}
protected:
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index b11a4ce36da9907ce8fe377c075023a4540797fa..8102faad28db71075fb8da269c55edbdb667193e 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -32,41 +32,30 @@ class ReduceWindowOp : public XlaOpKernel {
explicit ReduceWindowOp(OpKernelConstruction* context)
: XlaOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("computation", &computation_));
- OP_REQUIRES_OK(context,
- context->GetAttr("window_dimensions", &window_dimensions_));
- OP_REQUIRES_OK(context,
- context->GetAttr("window_strides", &window_strides_));
- OP_REQUIRES_OK(context, context->GetAttr("padding_low", &padding_low_));
- OP_REQUIRES_OK(context, context->GetAttr("padding_high", &padding_high_));
}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
const DataType dtype = context->input_type(0);
+ std::vector window_dimensions;
+ std::vector window_strides;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dimensions", &window_dimensions));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+
const int rank = input_shape.dims();
- OP_REQUIRES(context, rank == window_dimensions_.size(),
+ OP_REQUIRES(context, rank == window_dimensions.size(),
errors::InvalidArgument(
"The size of window_dimensions must be equal to the input "
"rank (",
- window_dimensions_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == window_strides_.size(),
+ window_dimensions.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_strides.size(),
errors::InvalidArgument(
"The size of window_strides must be equal to the input "
"rank (",
- window_strides_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == padding_low_.size(),
- errors::InvalidArgument(
- "The size of padding_low must be equal to the input "
- "rank (",
- padding_low_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == padding_high_.size(),
- errors::InvalidArgument(
- "The size of padding_high must be equal to the input "
- "rank (",
- padding_high_.size(), " vs. ", rank, ")"));
-
- xla::XlaBuilder* builder = context->builder();
+ window_strides.size(), " vs. ", rank, ")"));
// Build the reducer function.
XlaCompiler::Argument reducer_arg;
@@ -78,6 +67,7 @@ class ReduceWindowOp : public XlaOpKernel {
compile_options.use_tuple_arg = false;
compile_options.resolve_compile_time_constants = false;
compile_options.is_entry_computation = false;
+ compile_options.always_return_tuple = false;
XlaCompiler::CompilationResult reducer;
OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
compile_options, *computation_,
@@ -86,51 +76,47 @@ class ReduceWindowOp : public XlaOpKernel {
xla::Shape scalar_shape;
OP_REQUIRES_OK(context,
TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of ReduceWindow reducer. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
+
+ const TensorShape padding_shape = context->InputShape("padding");
OP_REQUIRES(context,
- xla::ShapeUtil::Compatible(
- reducer.xla_output_shape,
- xla::ShapeUtil::MakeTupleShape({scalar_shape})),
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
errors::InvalidArgument(
- "Invalid output shape of ReduceWindow reducer. Expected ",
- xla::ShapeUtil::HumanString(scalar_shape), " got ",
- xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
-
- // Wraps the reducer in a computation that unpacks the output tuple.
- xla::XlaComputation wrapper;
- {
- std::unique_ptr cb =
- builder->CreateSubBuilder("wrapper");
- auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x");
- auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y");
- auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y});
- xla::GetTupleElement(outputs, 0);
- xla::StatusOr result = cb->Build();
- OP_REQUIRES_OK(context, result.status());
- wrapper = std::move(result.ValueOrDie());
- }
-
- std::vector> padding(rank);
- for (int i = 0; i < rank; ++i) {
- padding[i] = {padding_low_[i], padding_high_[i]};
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get({i, 0}),
+ padding_literal.Get({i, 1})};
}
xla::XlaOp output = xla::ReduceWindowWithGeneralPadding(
- context->Input(0), context->Input(1), wrapper, window_dimensions_,
- window_strides_, padding);
+ context->Input(0), context->Input(1), *reducer.computation,
+ window_dimensions, window_strides, padding);
context->SetOutput(0, output);
}
private:
const NameAttrList* computation_;
- std::vector window_dimensions_;
- std::vector window_strides_;
- std::vector padding_low_;
- std::vector padding_high_;
TF_DISALLOW_COPY_AND_ASSIGN(ReduceWindowOp);
};
-REGISTER_XLA_OP(Name("XlaReduceWindow"), ReduceWindowOp);
+REGISTER_XLA_OP(Name("XlaReduceWindow")
+ .CompileTimeConstInput("window_dimensions")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("padding"),
+ ReduceWindowOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index b52f0a0ab6290f2019bb58120be5c2364ec15bb6..598248563bb93146e6dea3016822d26b8bf368e7 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -15,6 +15,7 @@ limitations under the License.
// XLA-specific reduction Ops.
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -29,9 +30,6 @@ namespace tensorflow {
XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
DataType reduction_type)
: XlaOpKernel(ctx), reduction_type_(reduction_type) {
- const DataType dt = BaseType(input_type(0));
- OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
-
OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
OP_REQUIRES_OK(
ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
@@ -58,20 +56,24 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
return;
}
+ OP_REQUIRES(ctx, axes_tensor_shape.dims() <= 1,
+ errors::InvalidArgument(
+ "Expected scalar or vector as index argument, got ",
+ axes_tensor_shape.DebugString()));
+
// Evaluate the constant, reshaping to a 1-vector if it is a scalar.
+ std::vector axes;
xla::Literal axes_literal;
- OP_REQUIRES_OK(
- ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()},
- &axes_literal));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector(1, &axes));
VLOG(1) << "data shape: " << data_shape.DebugString();
- VLOG(1) << "axes : " << axes_literal.ToString();
+ VLOG(1) << "axes : " << absl::StrJoin(axes, ",");
gtl::InlinedVector bitmap(data_shape.dims(), false);
std::vector xla_axes;
int64 num_elements_reduced = 1LL;
for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) {
- int32 index = axes_literal.Get({i});
+ int64 index = axes[i];
OP_REQUIRES(ctx,
!(index < -data_shape.dims() || index >= data_shape.dims()),
errors::InvalidArgument("Invalid reduction dimension (", index,
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index 64900e4709fd3e16d21096b0cfff8922906cb0d4..e172c649325adb6f7761ce0be141f21e8d545bc1 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -48,6 +48,15 @@ class RetvalOp : public XlaOpKernel {
} else {
xla::XlaOp input = ctx->Input(0);
const TensorShape input_shape = ctx->InputShape(0);
+ DataType input_type = ctx->input_type(0);
+ XlaContext& tc = XlaContext::Get(ctx);
+
+ if (input_type == DT_RESOURCE) {
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+ ctx->SetStatus(tc.AddResourceRetval(index_, resource));
+ return;
+ }
auto is_constant = ctx->builder()->IsConstant(input);
if (!is_constant.ok()) {
@@ -55,7 +64,6 @@ class RetvalOp : public XlaOpKernel {
return;
}
- XlaContext& tc = XlaContext::Get(ctx);
if (tc.resolve_compile_time_constants() &&
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
xla::Literal literal;
@@ -104,7 +112,8 @@ class RetvalOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
};
-REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp);
+REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(),
+ RetvalOp);
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc
index 6ce50efb4aa6e3434a7c6009cf9f52f6cff9cc9f..d9578eca5bf11110e9770b66a4dab82c597da6ee 100644
--- a/tensorflow/compiler/tf2xla/kernels/select_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc
@@ -67,7 +67,7 @@ class SelectOp : public XlaOpKernel {
// to get the dimensions in the right order.
const auto dim_sizes = then_shape.dim_sizes();
gtl::ArraySlice bdims = dim_sizes;
- bdims.pop_front();
+ bdims.remove_prefix(1);
cond_handle = xla::Broadcast(cond_handle, bdims);
std::vector dim_order(then_shape.dims());
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
index 025ba827410f1a9f993a8a1855558a2daa86609b..d6bd927135c013ac1ec3f6547aef358dc2741896 100644
--- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
// XLA-specific Ops for softmax.
+#include "absl/strings/match.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace {
@@ -33,7 +33,7 @@ namespace {
class SoftmaxOp : public XlaOpKernel {
public:
explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
- log_ = str_util::StartsWith(type_string(), "Log");
+ log_ = absl::StartsWith(type_string(), "Log");
}
void Compile(XlaOpKernelContext* ctx) override {
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..412afeaaad96842521fbd306f5b666e837e675fd
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc
@@ -0,0 +1,115 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaBroadcastHelperOp : public XlaOpKernel {
+ public:
+ explicit XlaBroadcastHelperOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ xla::XlaOp lhs = context->Input(0);
+ xla::XlaOp rhs = context->Input(1);
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+
+ const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims();
+ const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape;
+ const TensorShape* max_rank_shape = broadcast_lhs ? &rhs_shape : &lhs_shape;
+
+ std::vector broadcast_dims;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("broadcast_dims",
+ &broadcast_dims));
+ if (broadcast_dims.empty()) {
+ OP_REQUIRES(
+ context,
+ lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 ||
+ rhs_shape.dims() == 0,
+ errors::InvalidArgument(
+ "If broadcast_dims is empty, both "
+ "arguments must have equal rank; "
+ "argument shapes, or at least one argument must be a scalar: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ context->SetOutput(0, lhs);
+ context->SetOutput(1, rhs);
+ return;
+ }
+
+ OP_REQUIRES(
+ context, broadcast_dims.size() == min_rank_shape->dims(),
+ errors::InvalidArgument(
+ "broadcast_dims must have size equal to the smaller argument rank; "
+ "broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ std::vector sorted_broadcast_dims = broadcast_dims;
+ absl::c_sort(sorted_broadcast_dims);
+ std::set dims_set(broadcast_dims.begin(), broadcast_dims.end());
+ OP_REQUIRES(context,
+ dims_set.size() == broadcast_dims.size() &&
+ broadcast_dims == sorted_broadcast_dims,
+ errors::InvalidArgument(
+ "Duplicate or nonmonotonic dimension in broadcast_dims; "
+ "broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]"));
+
+ std::vector broadcast_shape(max_rank_shape->dims(), 1LL);
+ for (int i = 0; i < broadcast_dims.size(); ++i) {
+ const int dim = broadcast_dims[i];
+ OP_REQUIRES(
+ context, dim >= 0 && dim < broadcast_shape.size(),
+ errors::InvalidArgument(
+ "Invalid broadcast dimension (", dim, "); broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ broadcast_shape[dim] = min_rank_shape->dim_size(i);
+ }
+ xla::PrimitiveType type = context->input_xla_type(0);
+ xla::Shape broadcast_xla_shape =
+ xla::ShapeUtil::MakeShape(type, broadcast_shape);
+ if (broadcast_lhs) {
+ lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims);
+ } else {
+ rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims);
+ }
+ context->SetOutput(0, lhs);
+ context->SetOutput(1, rhs);
+ }
+
+ private:
+ xla::DotDimensionNumbers dnums_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaBroadcastHelperOp);
+};
+
+REGISTER_XLA_OP(
+ Name("XlaBroadcastHelper").CompileTimeConstInput("broadcast_dims"),
+ XlaBroadcastHelperOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8848623868091f8d19b1622f23ba23c68689d90d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaConvOp : public XlaOpKernel {
+ public:
+ explicit XlaConvOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ string precision_config_attr;
+ OP_REQUIRES_OK(
+ context, context->GetAttr("precision_config", &precision_config_attr));
+ OP_REQUIRES(
+ context,
+ precision_config_.ParsePartialFromString(precision_config_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+ const TensorShape padding_shape = context->InputShape("padding");
+ std::vector window_strides;
+ std::vector lhs_dilation;
+ std::vector rhs_dilation;
+ int64 feature_group_count;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("lhs_dilation",
+ &lhs_dilation));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("rhs_dilation",
+ &rhs_dilation));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(
+ "feature_group_count", &feature_group_count));
+
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
+ errors::InvalidArgument(
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get({i, 0}),
+ padding_literal.Get({i, 1})};
+ }
+
+ // We do only minimal checking, relying on XLA to check the shape
+ // invariants.
+ xla::XlaOp output = xla::ConvGeneralDilated(
+ context->Input(0), context->Input(1), window_strides, padding,
+ lhs_dilation, rhs_dilation, dnums_, feature_group_count,
+ &precision_config_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ xla::ConvolutionDimensionNumbers dnums_;
+ xla::PrecisionConfigProto precision_config_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp);
+};
+
+REGISTER_XLA_OP(Name("XlaConv")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("lhs_dilation")
+ .CompileTimeConstInput("rhs_dilation")
+ .CompileTimeConstInput("feature_group_count")
+ .CompileTimeConstInput("padding"),
+ XlaConvOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2fed53e5c072e1a50e0f07f45357ee86c90f986f
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaDotOp : public XlaOpKernel {
+ public:
+ explicit XlaDotOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ string precision_config_attr;
+ OP_REQUIRES_OK(
+ context, context->GetAttr("precision_config", &precision_config_attr));
+ OP_REQUIRES(
+ context,
+ precision_config_.ParsePartialFromString(precision_config_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+
+ // We do only minimal checking, relying on XLA to check the shape
+ // invariants.
+ xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1),
+ dnums_, &precision_config_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ xla::DotDimensionNumbers dnums_;
+ xla::PrecisionConfigProto precision_config_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp);
+};
+
+REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..59502d83c7338bd1b05b3323a97761fff2da186a
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc
@@ -0,0 +1,105 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaPadOp : public XlaOpKernel {
+ public:
+ explicit XlaPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape("input");
+ const TensorShape padding_value_shape =
+ context->InputShape("padding_value");
+
+ std::vector padding_low;
+ std::vector padding_high;
+ std::vector padding_interior;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_low",
+ &padding_low));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_high",
+ &padding_high));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "padding_interior", &padding_interior));
+
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(padding_value_shape),
+ errors::InvalidArgument("padding_value must be a scalar"));
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, rank == padding_low.size(),
+ errors::InvalidArgument(
+ "The size of padding_low must be equal to the input "
+ "rank (",
+ padding_low.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == padding_high.size(),
+ errors::InvalidArgument(
+ "The size of padding_high must be equal to the input "
+ "rank (",
+ padding_high.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == padding_interior.size(),
+ errors::InvalidArgument(
+ "The size of padding_interior must be equal to the input "
+ "rank (",
+ padding_interior.size(), " vs. ", rank, ")"));
+
+ auto non_negative = [](int64 x) { return x >= 0; };
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_low, non_negative),
+ errors::InvalidArgument("padding_low must be non-negative, got [",
+ absl::StrJoin(padding_low, ","), "]"));
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_high, non_negative),
+ errors::InvalidArgument("padding_high must be non-negative, got [",
+ absl::StrJoin(padding_high, ","), "]"));
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_interior, non_negative),
+ errors::InvalidArgument("padding_interior must be non-negative, got [",
+ absl::StrJoin(padding_interior, ","), "]"));
+
+ xla::PaddingConfig padding_config;
+ for (int i = 0; i < rank; ++i) {
+ auto* dim = padding_config.add_dimensions();
+ dim->set_edge_padding_low(padding_low[i]);
+ dim->set_edge_padding_high(padding_high[i]);
+ dim->set_interior_padding(padding_interior[i]);
+ }
+
+ xla::XlaOp output =
+ xla::Pad(context->Input("input"), context->Input("padding_value"),
+ padding_config);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaPadOp);
+};
+
+REGISTER_XLA_OP(Name("XlaPad")
+ .CompileTimeConstInput("padding_low")
+ .CompileTimeConstInput("padding_high")
+ .CompileTimeConstInput("padding_interior"),
+ XlaPadOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fc2425f37bfa793ce3a106b635c9dffd15b975ff
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaReduceOp : public XlaOpKernel {
+ public:
+ explicit XlaReduceOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("reducer", &reducer_));
+ OP_REQUIRES_OK(context, context->GetAttr("dimensions_to_reduce",
+ &dimensions_to_reduce_));
+ std::set dims_set(dimensions_to_reduce_.begin(),
+ dimensions_to_reduce_.end());
+ OP_REQUIRES(
+ context, dims_set.size() == dimensions_to_reduce_.size(),
+ errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce "
+ "argument to XlaReduce"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape("input");
+ const TensorShape init_value_shape = context->InputShape("init_value");
+ const DataType dtype = context->input_type(0);
+
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape),
+ errors::InvalidArgument("init_value must be a scalar"));
+
+ auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; };
+ OP_REQUIRES(context,
+ rank >= dimensions_to_reduce_.size() &&
+ absl::c_all_of(dimensions_to_reduce_, dim_in_range),
+ errors::InvalidArgument(
+ "Invalid dimensions_to_reduce argument to XlaReduce"));
+
+ // Build the reducer function.
+ XlaCompiler::Argument reducer_arg;
+ reducer_arg.kind = XlaCompiler::Argument::kParameter;
+ reducer_arg.type = dtype;
+ reducer_arg.shape = TensorShape();
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.use_tuple_arg = false;
+ compile_options.always_return_tuple = false;
+ compile_options.resolve_compile_time_constants = false;
+ compile_options.is_entry_computation = false;
+ XlaCompiler::CompilationResult reducer;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *reducer_,
+ {reducer_arg, reducer_arg}, &reducer));
+
+ xla::Shape scalar_shape;
+ OP_REQUIRES_OK(context,
+ TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of XlaReduce reducer. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
+
+ xla::XlaOp output =
+ xla::Reduce(context->Input("input"), context->Input("init_value"),
+ *reducer.computation, dimensions_to_reduce_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ const NameAttrList* reducer_;
+ std::vector dimensions_to_reduce_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp);
+};
+
+REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..089776fcf74fcf6b363dfff5de8d86d7449eacd6
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc
@@ -0,0 +1,147 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaSelectAndScatterOp : public XlaOpKernel {
+ public:
+ explicit XlaSelectAndScatterOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("select", &select_computation_));
+ OP_REQUIRES_OK(context, context->GetAttr("scatter", &scatter_computation_));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
+ const DataType dtype = context->input_type(0);
+
+ std::vector window_dimensions;
+ std::vector window_strides;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dimensions", &window_dimensions));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, rank == window_dimensions.size(),
+ errors::InvalidArgument(
+ "The size of window_dimensions must be equal to the input "
+ "rank (",
+ window_dimensions.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_strides.size(),
+ errors::InvalidArgument(
+ "The size of window_strides must be equal to the input "
+ "rank (",
+ window_strides.size(), " vs. ", rank, ")"));
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.use_tuple_arg = false;
+ compile_options.resolve_compile_time_constants = false;
+ compile_options.is_entry_computation = false;
+ compile_options.always_return_tuple = false;
+
+ // Build the select function.
+ XlaCompiler::Argument select_arg;
+ select_arg.kind = XlaCompiler::Argument::kParameter;
+ select_arg.type = dtype;
+ select_arg.shape = TensorShape();
+
+ XlaCompiler::CompilationResult select;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *select_computation_,
+ {select_arg, select_arg}, &select));
+
+ xla::Shape select_output_shape = xla::ShapeUtil::MakeShape(xla::PRED, {});
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(select.xla_output_shape,
+ select_output_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of XlaSelectAndScatter select. Expected ",
+ xla::ShapeUtil::HumanString(select_output_shape), " got ",
+ xla::ShapeUtil::HumanString(select.xla_output_shape)));
+
+ // Build the scatter function.
+ XlaCompiler::Argument scatter_arg;
+ scatter_arg.kind = XlaCompiler::Argument::kParameter;
+ scatter_arg.type = dtype;
+ scatter_arg.shape = TensorShape();
+
+ XlaCompiler::CompilationResult scatter;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *scatter_computation_,
+ {scatter_arg, scatter_arg}, &scatter));
+
+ xla::Shape scalar_shape;
+ OP_REQUIRES_OK(context,
+ TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(scatter.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of scatter. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(scatter.xla_output_shape)));
+
+ const TensorShape padding_shape = context->InputShape("padding");
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
+ errors::InvalidArgument(
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get({i, 0}),
+ padding_literal.Get({i, 1})};
+ }
+
+ xla::XlaOp output = xla::SelectAndScatterWithGeneralPadding(
+ context->Input("operand"), *select.computation, window_dimensions,
+ window_strides, padding, context->Input("source"),
+ context->Input("init_value"), *scatter.computation);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ const NameAttrList* select_computation_;
+ const NameAttrList* scatter_computation_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaSelectAndScatterOp);
+};
+
+REGISTER_XLA_OP(Name("XlaSelectAndScatter")
+ .CompileTimeConstInput("window_dimensions")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("padding"),
+ XlaSelectAndScatterOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index cb7a40e23d539f758d963791f1c2b4d37374ade5..99511e991422014c877fb5f6b7fb6a914e730f40 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -25,8 +25,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -44,8 +44,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/core:lib",
],
@@ -78,8 +78,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
@@ -119,6 +119,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:constants",
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index f666d22ea44216beef74608bb4d9f33fb2fe82c6..d8c050d09e871c80e128989c9fbdb57c266b19ed 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -27,7 +27,8 @@ limitations under the License.
namespace tensorflow {
xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
- bool transpose_y, bool conjugate_x, bool conjugate_y) {
+ bool transpose_y, bool conjugate_x, bool conjugate_y,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr {
TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
@@ -95,6 +96,10 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
y = xla::Conj(y);
}
+ xla::PrecisionConfigProto precision_proto;
+ precision_proto.add_operand_precision(precision);
+ precision_proto.add_operand_precision(precision);
+
// If there are no batch dimensions, use a regular Dot.
// TODO(b/69062148) Remove this code when Dot emitters can be passed
// dimensions to transpose directly (i.e. without requiring a Transpose
@@ -102,7 +107,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
if (batch_dimension_numbers.empty()) {
auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x;
auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y;
- return xla::Dot(lhs, rhs);
+ return xla::Dot(lhs, rhs, &precision_proto);
}
xla::DotDimensionNumbers dot_dnums;
@@ -112,7 +117,8 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
}
- return xla::DotGeneral(x, y, dot_dnums);
+
+ return xla::DotGeneral(x, y, dot_dnums, &precision_proto);
});
}
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index 8757b16a1ca6a8cec5e3c801c885e7bbbb2f2c76..6cfccd55530ff40a309673d57d1fe61fc8264316 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -45,7 +45,9 @@ namespace tensorflow {
// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false,
bool transpose_y = false, bool conjugate_x = false,
- bool conjugate_y = false);
+ bool conjugate_y = false,
+ xla::PrecisionConfigProto::Precision precision =
+ xla::PrecisionConfigProto::DEFAULT);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 87d73eb3f07ebd7dfa4fef50ebe76cad0c4ed117..67fb56510cbd0677a2b78e2090f98b602539c6bd 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -49,7 +49,8 @@ namespace {
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
// l[..., j, j]
// return l
-xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
+xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -101,7 +102,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
// np.dot(row, np.swapaxes(row, -1, -2))
auto diag_dot = BatchDot(row, row,
/*transpose_x=*/false,
- /*transpose_y=*/true);
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
// l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
// np.swapaxes(row, -1, -2)))
auto l_ii =
@@ -121,7 +123,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
// r.T)
auto dot = BatchDot(body_l, row,
/*transpose_x=*/false,
- /*transpose_y=*/true);
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
// np.dot(l[..., i+1:, :i], r.T)
auto dot_ip1 =
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);
@@ -145,7 +148,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
} // namespace
-xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) {
+xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -181,14 +185,15 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) {
auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false,
- /*transpose_y=*/true);
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
auto before = SliceInMinorDims(a, {i, i}, {n, i + k});
a = UpdateSliceInMinorDims(a, before - delta, {i, i});
}
// l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k});
- auto factorized = CholeskyUnblocked(x);
+ auto factorized = CholeskyUnblocked(x, precision);
l = UpdateSliceInMinorDims(l, factorized, {i, i});
if (i + k < n) {
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index 1bef9bb166c576ec665bb48265b4da200ddca2a0..60cd7ded53fe862f29ca2bb68b175fcd1c89b70c 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -30,7 +30,9 @@ namespace tensorflow {
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(znado): handle the complex Hermitian case
-xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256);
+xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256,
+ xla::PrecisionConfigProto::Precision precision =
+ xla::PrecisionConfigProto::HIGHEST);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc
index fc0c1ee838190b1f1a7ca5b901c97e0a35232a97..b6f30d8d49bf05813fa6fccc4544b0631f866490 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.cc
+++ b/tensorflow/compiler/tf2xla/lib/qr.cc
@@ -149,7 +149,8 @@ struct QRBlockResult {
xla::XlaOp taus; // Shape: [..., n]
xla::XlaOp vs; // Shape: [..., m, n]
};
-xla::StatusOr QRBlock(xla::XlaOp a) {
+xla::StatusOr QRBlock(
+ xla::XlaOp a, xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int num_dims = xla::ShapeUtil::Rank(a_shape);
@@ -190,8 +191,12 @@ xla::StatusOr QRBlock(xla::XlaOp a) {
auto v_broadcast = xla::Reshape(v, shape);
// a[:, :] -= tau * np.dot(v[:, np.newaxis],
// np.dot(v[np.newaxis, :], a[:, :]))
- auto vva = BatchDot(v_broadcast, a);
- vva = BatchDot(v_broadcast, vva, /*transpose_x=*/true);
+ auto vva =
+ BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
+ vva =
+ BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
a = a - xla::Mul(tau, vva,
/*broadcast_dimensions=*/batch_dim_indices);
@@ -251,7 +256,8 @@ xla::StatusOr QRBlock(xla::XlaOp a) {
// vs.
xla::StatusOr ComputeWYRepresentation(
xla::PrimitiveType type, gtl::ArraySlice batch_dims, xla::XlaOp vs,
- xla::XlaOp taus, int64 m, int64 n) {
+ xla::XlaOp taus, int64 m, int64 n,
+ xla::PrecisionConfigProto::Precision precision) {
std::vector batch_dim_indices(batch_dims.size());
std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
int64 n_index = batch_dims.size() + 1;
@@ -272,9 +278,12 @@ xla::StatusOr ComputeWYRepresentation(
auto beta = DynamicSliceInMinorDims(taus, {j}, {1});
// yv has shape [..., n, 1]
- auto yv = BatchDot(y, v, /*transpose_x=*/true);
+ auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
// wyv has shape [..., m, 1]
- auto wyv = BatchDot(w, yv);
+ auto wyv =
+ BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
auto z = xla::Mul(
-beta, v + wyv,
@@ -321,8 +330,9 @@ xla::StatusOr ComputeWYRepresentation(
// return (q, a)
// TODO(phawkins): consider using UT transformations (in the form I - V U V')
// rather than WY transformations.
-xla::StatusOr QRDecomposition(xla::XlaOp a,
- int64 block_size) {
+xla::StatusOr QRDecomposition(
+ xla::XlaOp a, int64 block_size,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int num_dims = xla::ShapeUtil::Rank(a_shape);
@@ -352,29 +362,36 @@ xla::StatusOr QRDecomposition(xla::XlaOp a,
int64 k = std::min(block_size, p - i);
auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
- TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block));
+ TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision));
a = UpdateSliceInMinorDims(a, qr_block.r, {i, i});
// Compute the I-WY block representation of a product of Householder
// matrices.
- TF_ASSIGN_OR_RETURN(auto w,
- ComputeWYRepresentation(type, batch_dims, qr_block.vs,
- qr_block.taus, m - i, k));
+ TF_ASSIGN_OR_RETURN(
+ auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs,
+ qr_block.taus, m - i, k, precision));
auto y = qr_block.vs;
// a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
- auto a_update = BatchDot(w, a_panel, /*transpose_x=*/true);
- a_update = BatchDot(y, a_update);
+ auto a_update =
+ BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
+ a_update =
+ BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
a_panel = a_panel + a_update;
a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
// q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T))
auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
- auto q_update = BatchDot(q_panel, w);
- q_update =
- BatchDot(q_update, y, /*transpose_x=*/false, /*transpose_y=*/true);
+ auto q_update =
+ BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
+ q_update = BatchDot(q_update, y, /*transpose_x=*/false,
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
q_panel = q_panel + q_update;
q = UpdateSliceInMinorDims(q, q_panel, {0, i});
}
diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h
index abd2316ac961f583dd29f90f43cf6209de30bd6a..05565477b6062618a75f929b69c38938ddfd7a5a 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.h
+++ b/tensorflow/compiler/tf2xla/lib/qr.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -32,8 +33,10 @@ struct QRDecompositionResult {
xla::XlaOp r;
};
-xla::StatusOr QRDecomposition(xla::XlaOp a,
- int64 block_size = 128);
+xla::StatusOr QRDecomposition(
+ xla::XlaOp a, int64 block_size = 128,
+ xla::PrecisionConfigProto::Precision precision =
+ xla::PrecisionConfigProto::HIGHEST);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index ba22eff73abab11abeb57283c63318b2e50a9ca1..bafe5099f2d494fd3549fae41397ffc5a22f5cb7 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -58,7 +58,7 @@ xla::StatusOr XlaScatter(
") must be <= the rank of the buffer (shape: ",
xla::ShapeUtil::HumanString(buffer_shape), ")");
}
- indices_dims.pop_back();
+ indices_dims.remove_suffix(1);
}
int64 num_indices = 1;
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index febb638e5e8a87d78919f1eaa556d9c05ee40112..37b2240b45b4ae6a587c827cfdfa1096b4e1737e 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -110,8 +110,9 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
});
}
-xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
- bool transpose_a, bool conjugate_a) {
+xla::XlaOp InvertDiagonalBlocks(
+ xla::XlaOp diag_blocks, bool lower, bool transpose_a, bool conjugate_a,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = diag_blocks.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr {
// Input is a batch of square lower triangular square matrices. Its shape is
@@ -215,7 +216,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
dnums.add_rhs_batch_dimensions(0);
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
- auto update = -DotGeneral(input_row, body_out, dnums);
+ xla::PrecisionConfigProto precision_proto;
+ precision_proto.add_operand_precision(precision);
+ precision_proto.add_operand_precision(precision);
+ auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
body_out = DynamicUpdateSlice(body_out, update, start_indices);
@@ -238,10 +242,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
});
}
-xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
- xla::XlaOp inv_diag_blocks,
- bool left_side, bool lower,
- bool transpose_a, bool conjugate_a) {
+xla::XlaOp SolveWithInvertedDiagonalBlocks(
+ xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side,
+ bool lower, bool transpose_a, bool conjugate_a,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr {
TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape,
@@ -307,9 +311,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
auto a_row =
MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
if (left_side) {
- remainder = b_row - BatchDot(a_row, x, transpose_a, false);
+ remainder = b_row - BatchDot(a_row, x, transpose_a, false,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
} else {
- remainder = b_row - BatchDot(x, a_row, false, transpose_a);
+ remainder = b_row - BatchDot(x, a_row, false, transpose_a,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
}
}
@@ -319,9 +327,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
xla::ConstantR0WithType(builder, xla::S32, j * block_size);
std::vector update_starts = {start_index, zero};
if (left_side) {
- x_update = BatchDot(inv_block, remainder, transpose_a, false);
+ x_update =
+ BatchDot(inv_block, remainder, transpose_a, false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
} else {
- x_update = BatchDot(remainder, inv_block, false, transpose_a);
+ x_update =
+ BatchDot(remainder, inv_block, false, transpose_a,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
std::swap(update_starts[0], update_starts[1]);
}
x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts);
@@ -333,7 +345,8 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
bool lower, bool transpose_a, bool conjugate_a,
- int64 block_size) {
+ int64 block_size,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -388,12 +401,13 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
auto diag_blocks = DiagonalBlocks(a, block_size);
// We invert these blocks in parallel using batched matrix-vector products
- auto inv_diag_blocks =
- InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a);
+ auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a,
+ conjugate_a, precision);
// We now find the solution using GEMMs
- auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side,
- lower, transpose_a, conjugate_a);
+ auto x =
+ SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower,
+ transpose_a, conjugate_a, precision);
return x;
});
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
index 555760b7efabddfb25c9135b109a1c48b487415e..ac42a4835295b7cb52697710d738f4728d3983d1 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -59,7 +59,9 @@ namespace tensorflow {
// blocking is used.
xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
bool lower, bool transpose_a, bool conjugate_a,
- int64 block_size = 128);
+ int64 block_size = 128,
+ xla::PrecisionConfigProto::Precision precision =
+ xla::PrecisionConfigProto::HIGHEST);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD
index ace6fd1d8eeaf439509a7b75d8d986997c392e73..4dce0a2102cf9c782850ccc7af4f14b59bd51e53 100644
--- a/tensorflow/compiler/tf2xla/ops/BUILD
+++ b/tensorflow/compiler/tf2xla/ops/BUILD
@@ -11,6 +11,8 @@ cc_library(
srcs = ["xla_ops.cc"],
deps = [
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
alwayslink = 1,
)
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index a59c77f5c3a309abe8f6fbab1e48455d54e8fae5..2cd9ae799f06afdcbae5429ef8caffd3b4d29c29 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -13,11 +13,97 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/algorithm/container.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
+namespace {
+
+// Helper shape function for operators that return an output with the same rank
+// as their first input.
+Status UnchangedRank(shape_inference::InferenceContext* c) {
+ if (c->RankKnown(c->input(0))) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
+ } else {
+ c->set_output(0, c->input(0));
+ }
+ return Status::OK();
+}
+
+REGISTER_OP("XlaBroadcastHelper")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Input("broadcast_dims: Tindices")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Output("lhs_output: T")
+ .Output("rhs_output: T")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Helper operator for performing XLA-style broadcasts
+
+Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
+whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
+for binary operators.
+
+lhs: the LHS input tensor
+rhs: the RHS input tensor
+broadcast_dims: an XLA-style broadcast dimension specification
+lhs_output: the broadcasted LHS tensor
+rhs_output: the broadcasted RHS tensor
+)doc");
+
+REGISTER_OP("XlaConv")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Input("window_strides: Tindices")
+ .Input("padding: Tindices")
+ .Input("lhs_dilation: Tindices")
+ .Input("rhs_dilation: Tindices")
+ .Input("feature_group_count: Tindices")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("dimension_numbers: string")
+ .Attr("precision_config: string")
+ .Output("output: T")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA ConvGeneralDilated operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
+.
+
+lhs: the input tensor
+rhs: the kernel tensor
+window_strides: the inter-window strides
+padding: the padding to apply at the start and end of each input dimensions
+lhs_dilation: dilation to apply between input elements
+rhs_dilation: dilation to apply between kernel elements
+feature_group_count: number of feature groups for grouped convolution.
+dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfigProto proto.
+)doc");
+
+REGISTER_OP("XlaDot")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Attr("T: numbertype")
+ .Attr("dimension_numbers: string")
+ .Attr("precision_config: string")
+ .Output("output: T")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Wraps the XLA ConvGeneralDilated operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
+.
+
+lhs: the LHS tensor
+rhs: the RHS tensor
+dimension_numbers: a serialized xla::DotDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfigProto proto.
+)doc");
REGISTER_OP("XlaDynamicUpdateSlice")
.Input("input: T")
@@ -73,6 +159,29 @@ else_branch: A function takes 'inputs' and returns a list of tensors.
whose types are the same as what then_branch returns.
)doc");
+REGISTER_OP("XlaPad")
+ .Input("input: T")
+ .Input("padding_value: T")
+ .Input("padding_low: Tindices")
+ .Input("padding_high: Tindices")
+ .Input("padding_interior: Tindices")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA Pad operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#pad
+.
+
+input: A `Tensor` of type T.
+padding_value: A scalar `Tensor` of type T.
+padding_low: the padding to apply at the start of each input dimensions
+padding_high: the padding to apply at the end of each input dimension.
+padding_interior: the padding to apply between each input element.
+output: A `Tensor` of type T.
+)doc");
+
REGISTER_OP("XlaRecv")
.Output("tensor: dtype")
.Attr("dtype: type")
@@ -98,17 +207,58 @@ tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");
+REGISTER_OP("XlaReduce")
+ .Input("input: T")
+ .Input("init_value: T")
+ .Attr("T: numbertype")
+ .Attr("dimensions_to_reduce: list(int)")
+ .Attr("reducer: func")
+ .Output("output: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ if (c->RankKnown(c->input(0))) {
+ int rank = c->Rank(c->input(0));
+ std::vector