diff --git a/RELEASE.md b/RELEASE.md
index 2717c75740aeea7821fb6c57dfc85908e86e9d51..84d9d52868ecd55d38d6073315749d11c2340e8c 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -6,7 +6,7 @@
* Added Gradient Boosted Trees as pre-made Estimators: BoostedTreesClassifier, BoostedTreesRegressor.
* Add 3rd generation pipeline config for Cloud TPUs which improves performance and usability.
* `tf.contrib.bayesflow` is moving out to it's own repo.
-* Added `tf.contrib.{proto,rpc}` to allow generic proto parsing and RPC communication.
+* Added `tf.contrib.{proto,rpc}` to allow generic proto parsing and RPC communication[1](#rpc-issue).
## Bug Fixes and Other Changes
* `tf.data`:
@@ -49,13 +49,14 @@
* Fix non-uniformity of orthogonal matrices.
* Fix bug where multi-image Estimator eval summaries were not displayed correctly.
+1 The cancellation logic of the RPC op contains a concurrency error. A fix has been submitted to master and will be part of the next release.
+
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
4d55397500, Aghasy, Alan Du, Alan Lee, Alan Yee, Alex Wiltschko, Animesh Karnewar, Ankit Gupta, Anton Matosov, Aris L, Ben Barsdell, Brent Yi, Brett Koonce, Carl Thomé, cbockman, Chikanaga Tomoyuki, Chris Tava, CéDric Deltheil, Dahan Gong, Dalmo Cirne, Daniel Erenrich, David Norman, DavidNorman, Edd Wilder-James, Fanjin Zeng, Felix Abecassis, fo40225, George Sterpu, Giovanni Terlingen, Gor Baghdasaryan, Guillaume Klein, Hanchen Li, Ilya Polenov, Jakub Kolodziejczyk, Jason Sadler, Jayaram Bobba, Jerry Liu, jinghuangintel, Jiongyan Zhang (张炯衍), Joel Shor, Jong Wook Kim, Julian Eisenschlos, Karl Lessard, Krish Ravindranath, Loo Rong Jie, Lukas Geiger, Luke Iwanski, Mahmoud Abuzaina, ManHyuk, Marvin Richter, Maximilian Mitchell, Mohammad Ashraf Bhuiyan, msofka, Mustafa Kasap, Nathan Burnham, Nathan Luehr, Naveen Marri, ngc92, nio1814, Oleg Zabluda, Ou Changkun, Panos Ipeirotis, Paul Van Eck, Peter Lee, Piotr Czapla, qjivy, Rholais Lii, Rodrigo Formigone, Russell Klopfer, ryantimjohn, Sang Han, SebastiáN RamíRez, shengfuintel, Siby Jose Plathottam, Silver Chan, Stanislaw Antol, Taehoon Lee, Tarang Chugh, Ted Chang, Thomas Bastiani, Xian Xu, Xiaoming (Jason) Cui, Yan Facai (颜发才), yaox12, Yashal Shakti Kanungo, Yong Tang, Yuan (Terry) Tang, Yuxin Wu, Ziyue(Louis) Lu
-
# Release 1.7.0
## Major Features And Improvements
@@ -235,7 +236,7 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田
* Add `complex64` support to XLA compiler.
* `bfloat` support is now added to XLA infrastructure.
* Make `ClusterSpec` propagation work with XLA devices.
- * Use a determinisitic executor to generate XLA graph.
+ * Use a deterministic executor to generate XLA graph.
* `tf.contrib`:
* `tf.contrib.distributions`:
* Add `tf.contrib.distributions.Autoregressive`.
diff --git a/SECURITY.md b/SECURITY.md
index a5ce3a62ee202f6e7d83f0fedc2777d9c88ba9b5..01886b613e5d93793953124331b57f075fe7a373 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -173,7 +173,7 @@ the progress being made towards a fix and announcement.
In addition, please include the following information along with your report:
* Your name and affiliation (if any).
-* A description the technical details of the vulnerabilities. It is very
+* A description of the technical details of the vulnerabilities. It is very
important to let us know how we can reproduce your findings.
* An explanation who can exploit this vulnerability, and what they gain when
doing so -- write an attack scenario. This will help us evaluate your report
diff --git a/configure.py b/configure.py
index b745e374a2baaffec73f9f9382e1bab322e7f0fd..6d9aba61bbc73ba1b80321d6859877c371dc5427 100644
--- a/configure.py
+++ b/configure.py
@@ -845,8 +845,8 @@ def reformat_version_sequence(version_str, sequence_count):
def set_tf_cuda_version(environ_cp):
"""Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
ask_cuda_version = (
- 'Please specify the CUDA SDK version you want to use, '
- 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
+ 'Please specify the CUDA SDK version you want to use. '
+ '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
# Configure the Cuda SDK version to use.
@@ -1226,6 +1226,9 @@ def set_tf_cuda_compute_capabilities(environ_cp):
ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
# Check whether all capabilities from the input is valid
all_valid = True
+ # Remove all whitespace characters before splitting the string
+ # that users may insert by accident, as this will result in error
+ tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split())
for compute_capability in tf_cuda_compute_capabilities.split(','):
m = re.match('[0-9]+.[0-9]+', compute_capability)
if not m:
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 18eeb2816807ec9986999cfc2c9a4c0f032683c0..b86b277ac3200b88ae03490a6c1b64d464e81950 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -2097,7 +2097,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
for (int i = 0; i < size; ++i) {
TensorId id = results.missing_unused_input_map_keys[i];
- tf_results->missing_unused_key_names_data.push_back(id.first.ToString());
+ tf_results->missing_unused_key_names_data.push_back(std::string(id.first));
tf_results->missing_unused_key_names[i] =
tf_results->missing_unused_key_names_data.back().c_str();
tf_results->missing_unused_key_indexes[i] = id.second;
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 82dbd3cdbc6e8fb0c6fbcddb33b6a95c87a83225..95b04f9058afdfaadbc24f0238860279fcd3e800 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -8407,3 +8407,51 @@ TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
}
return ret;
}
+
+void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
+ TF_Tensor* tensor, TF_Status* status) {
+ assert(session);
+ {
+ tensorflow::mutex_lock c(session->graph->mu);
+ if (VLOG_IS_ON(1)) {
+ VLOG(1) << "Enqueuing named tensor with id " << tensor_id
+ << ", with input graph: "
+ << session->graph->graph.ToGraphDefDebug().DebugString();
+ tensorflow::Tensor internal_tensor;
+ if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
+ VLOG(1) << "Enqueu'ing tensor content: "
+ << internal_tensor.DebugString();
+ }
+ }
+ }
+
+ TF_Operation* enqueue_op = TF_GraphOperationByName(
+ session->graph,
+ tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
+ if (enqueue_op == nullptr) {
+ status->status = tensorflow::errors::Internal(
+ "Unable to find the enqueue node in the TF graph.");
+ return;
+ }
+
+ TF_Operation* placeholder_op = TF_GraphOperationByName(
+ session->graph,
+ tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
+ if (placeholder_op == nullptr) {
+ status->status = tensorflow::errors::Internal(
+ "Unable to find the placeholder node as input to enqueue in the TF "
+ "graph.");
+ return;
+ }
+
+ VLOG(1) << "Running the enqueue op";
+ TF_Output input{placeholder_op, 0};
+ TF_SessionRun(session, /*run_options*/ nullptr,
+ // input related parameters
+ /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
+ // output related parameters
+ /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
+ /*targets*/ &enqueue_op, /*ntargets*/ 1,
+ /*run_metadata*/ nullptr, status);
+ VLOG(1) << "Enqueuing is done.";
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index e6757c065fc540fa789cdbb694e66ca0b00c4832..20bdace40f1272ded06e710034053a7610326e7f 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -87,8 +87,11 @@ TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
unsigned char is_mnist, TF_Status* status);
// On success, dequeues a tensor from a TF-managed FifoQueue given by
-// `tensor_id`, associated with `session`. Caller must call TF_DeleteTensor()
-// over the returned tensor. If the queue is empty, this call is blocked.
+// `tensor_id`, associated with `session`. There must be a graph node named
+// "fifo_queue_dequeue_", to be executed by this API call.
+
+// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is
+// empty, this call is blocked.
//
// Tensors are enqueued via the corresponding TF enqueue op.
// TODO(hongm): Add support for `timeout_ms`.
@@ -96,6 +99,22 @@ TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session,
int tensor_id,
TF_Status* status);
+// On success, enqueues `tensor` into a TF-managed FifoQueue given by
+// `tensor_id`, associated with `session`. There must be a graph node named
+// "fifo_queue_enqueue_", to be executed by this API call. It reads
+// from a placeholder node "arg_tensor_enqueue_".
+//
+// `tensor` is still owned by the caller. This call will be blocked if the queue
+// has reached its capacity, and will be unblocked when the queued tensors again
+// drop below the capacity due to dequeuing.
+//
+// Tensors are dequeued via the corresponding TF dequeue op.
+// TODO(hongm): Add support for `timeout_ms`.
+TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
+ int tensor_id,
+ TF_Tensor* tensor,
+ TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 9b86425aa5fbc2be2872b3f5d2809eaa844f9d68..577f10c5e69ea9ecbe8ce821c6bd5167e98bef25 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1368,7 +1368,7 @@ TEST(CAPI, SavedModel) {
}
const tensorflow::string input_op_name =
- tensorflow::ParseTensorName(input_name).first.ToString();
+ std::string(tensorflow::ParseTensorName(input_name).first);
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
@@ -1376,7 +1376,7 @@ TEST(CAPI, SavedModel) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
const tensorflow::string output_op_name =
- tensorflow::ParseTensorName(output_name).first.ToString();
+ std::string(tensorflow::ParseTensorName(output_name).first);
TF_Operation* output_op =
TF_GraphOperationByName(graph, output_op_name.c_str());
ASSERT_TRUE(output_op != nullptr);
diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc
index b1f7bdaa5420a56386e6983052df20aa976aa867..74bc25a491ac01cb725d1c004197e48727c30230 100644
--- a/tensorflow/c/checkpoint_reader.cc
+++ b/tensorflow/c/checkpoint_reader.cc
@@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() {
const auto& slice_proto = entry.slices(i);
CHECK(filtered_keys
.insert(EncodeTensorNameSlice(
- v2_reader_->key().ToString() /* full var's name */,
+ std::string(v2_reader_->key()) /* full var's name */,
TensorSlice(slice_proto)))
.second);
}
@@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() {
new TensorSliceReader::VarToDataTypeMap);
v2_reader_->Seek(kHeaderEntryKey);
for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
- if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue;
+ if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue;
CHECK(entry.ParseFromArray(v2_reader_->value().data(),
v2_reader_->value().size()))
<< entry.InitializationErrorString();
- string key = v2_reader_->key().ToString();
+ string key = std::string(v2_reader_->key());
(*var_to_shape_map)[key] = TensorShape(entry.shape());
(*var_to_data_type_map)[key] = DataType(entry.dtype());
}
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 97c323b87228039ba10f4ed5e434aa83621b1220..e9ed3395c448305bcd6317b0b292b4e4e0b659b1 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -130,13 +130,15 @@ class GradientTape {
}
}
- bool ShouldRecord(gtl::ArraySlice tensor_ids);
+ bool ShouldRecord(gtl::ArraySlice tensor_ids,
+ gtl::ArraySlice dtypes);
void Watch(int64 tensor_id);
void RecordOperation(const string& op_type,
gtl::ArraySlice output_tensors,
gtl::ArraySlice input_tensor_id,
+ gtl::ArraySlice input_dtypes,
BackwardFunction* backward_function,
const std::function& backward_function_deleter);
@@ -170,12 +172,30 @@ class GradientTape {
// Template instantiations here
+inline bool IsDtypeTrainable(DataType dtype) {
+ switch (dtype) {
+ case DT_HALF:
+ case DT_BFLOAT16:
+ case DT_FLOAT:
+ case DT_DOUBLE:
+ case DT_COMPLEX64:
+ case DT_COMPLEX128:
+ case DT_RESOURCE:
+ case DT_VARIANT:
+ return true;
+ default:
+ return false;
+ }
+}
+
template
bool GradientTape::ShouldRecord(
- gtl::ArraySlice tensor_ids) {
- for (int64 i : tensor_ids) {
- if (tensor_tape_.find(i) != tensor_tape_.end()) {
- return true;
+ gtl::ArraySlice tensor_ids,
+ gtl::ArraySlice dtypes) {
+ CHECK_EQ(tensor_ids.size(), dtypes.size());
+ for (int i = 0; i < tensor_ids.size(); ++i) {
+ if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
+ return IsDtypeTrainable(dtypes[i]);
}
}
return false;
@@ -189,9 +209,11 @@ void GradientTape::Watch(int64 tensor_id) {
template
void GradientTape::RecordOperation(
const string& op_type, gtl::ArraySlice output_tensors,
- gtl::ArraySlice input_tensor_id, BackwardFunction* backward_function,
+ gtl::ArraySlice input_tensor_id,
+ gtl::ArraySlice input_dtypes,
+ BackwardFunction* backward_function,
const std::function& backward_function_deleter) {
- if (!ShouldRecord(input_tensor_id)) {
+ if (!ShouldRecord(input_tensor_id, input_dtypes)) {
backward_function_deleter();
return;
}
@@ -380,49 +402,39 @@ Status InitialGradients(const VSpace& vspace,
gtl::ArraySlice output_gradients,
const TensorTape& tensor_tape,
const OpTape& op_tape,
- const gtl::FlatMap& tensor_usage_counts,
gtl::FlatMap>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
- if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
- if (!output_gradients.empty() && output_gradients[i] != nullptr) {
- // TODO(apassos) figure out how to print debugging information here.
- return errors::InvalidArgument(
- "A gradient was provided for a tensor which is used as part of the "
- "computation.");
- }
- } else {
- if (output_gradients.empty() || output_gradients[i] == nullptr) {
- auto tensor_it = tensor_tape.find(id);
- if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
- auto op_it = op_tape.find(tensor_it->second);
- if (op_it == op_tape.end()) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid: "
- "failed to find operation producing a tensor");
- }
- bool found = false;
- for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
- if (op_it->second.output_tensor_info[j].id == id) {
- found = true;
- (*result)[id].push_back(
- vspace.Ones(op_it->second.output_tensor_info[j].shape,
- op_it->second.output_tensor_info[j].dtype));
- break;
- }
- }
- if (!found) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid: "
- "none of operations outputs match expected tensor");
+ if (output_gradients.empty() || output_gradients[i] == nullptr) {
+ auto tensor_it = tensor_tape.find(id);
+ if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
+ auto op_it = op_tape.find(tensor_it->second);
+ if (op_it == op_tape.end()) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid: "
+ "failed to find operation producing a tensor");
+ }
+ bool found = false;
+ for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
+ if (op_it->second.output_tensor_info[j].id == id) {
+ found = true;
+ (*result)[id].push_back(
+ vspace.Ones(op_it->second.output_tensor_info[j].shape,
+ op_it->second.output_tensor_info[j].dtype));
+ break;
}
- } else {
- // No record of the target tensor found on the tape, so no gradient
- // needs to be computed from it. Do nothing.
+ }
+ if (!found) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid: "
+ "none of operations outputs match expected tensor");
}
} else {
- (*result)[id].push_back(output_gradients[i]);
+ // No record of the target tensor found on the tape, so no gradient
+ // needs to be computed from it. Do nothing.
}
+ } else {
+ (*result)[id].push_back(output_gradients[i]);
}
}
return Status::OK();
@@ -451,8 +463,7 @@ Status GradientTape::ComputeGradient(
InitialStack(state.op_tape, state.op_missing_tensor);
gtl::FlatMap> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
- tensor_tape_, state.op_tape,
- state.tensor_usage_counts, &gradients);
+ tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() {
if (!persistent_) {
// Release all backprop functions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index d73121c7b701ec06c03836d1a765f4b35d88fe92..d6a4f141b6bb8ccadb77f1fa83b5fb742d78f70f 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -440,7 +440,7 @@ string AvoidCPPKeywords(StringPiece name) {
if (IsCPPKeyword(name)) {
return strings::StrCat(name, "_");
}
- return name.ToString();
+ return std::string(name);
}
void InferArgAttributes(const OpDef::ArgDef& arg,
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index c143b978338815ebc7134eb0a07867c5d8b13dca..62a889181e787f2e181135ab0563c45e1bab8812 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -220,7 +220,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints(
for (const string& entry : node_constraints) {
StringPiece s(entry);
if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
- current_constraints.insert(s.ToString());
+ current_constraints.insert(std::string(s));
}
}
} else {
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 31044ff85d6f0d72b34d03669fe508866d7d3358..bbc35da2ef6d14ff0d3570ef2d5cf6743456c674 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -44,7 +44,7 @@ namespace {
// Compiles the XLA computation into executable code.
Status CompileXla(xla::CompileOnlyClient* client,
- const xla::Computation& computation,
+ const xla::XlaComputation& computation,
const xla::cpu::CpuAotCompilationOptions& aot_opts,
CompileResult* compile_result) {
// Retrieves arg and result layouts from the computation.
@@ -62,7 +62,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i));
}
- xla::CompileOnlyClient::AotComputationInstance instance;
+ xla::CompileOnlyClient::AotXlaComputationInstance instance;
instance.computation = &computation;
instance.argument_layouts = std::move(arg_layouts);
instance.result_layout = &pshape->result();
@@ -93,14 +93,14 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
xla::CompileOnlyClient* client =
xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
.ValueOrDie();
- xla::Computation computation;
+ xla::XlaComputation computation;
TF_RETURN_IF_ERROR(
ConvertGraphDefToXla(graph_def, config, client, &computation));
if (!flags.out_session_module.empty()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr module,
+ TF_ASSIGN_OR_RETURN(std::unique_ptr module,
computation.Snapshot());
- // Serialize the SessionModule deterministically so that all the outputs of
- // a tf_library genrule are deterministic.
+ // Serialize the HloSnapshot deterministically so that all the outputs of a
+ // tf_library genrule are deterministic.
string proto;
TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto));
TF_RETURN_IF_ERROR(
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 222e26810ac1157152ea81a56749b6652aa1f137..fd2cf2b67d4618dd626b8eef78eed044d7fde0a4 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -15,6 +15,7 @@ test_suite(
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
":test_graph_tfassert_eq_test",
+ ":test_graph_tfcond_test",
":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
@@ -55,6 +56,7 @@ genrule(
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt_saver.saver",
"test_graph_tfassert_eq.pb",
+ "test_graph_tfcond.pb",
"test_graph_tffunction.pb",
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
@@ -118,6 +120,17 @@ tf_library(
],
)
+tf_library(
+ name = "test_graph_tfcond",
+ testonly = 1,
+ config = "test_graph_tfcond.config.pbtxt",
+ cpp_class = "CondComp",
+ graph = "test_graph_tfcond.pb",
+ tags = [
+ "manual",
+ ],
+)
+
tf_library(
name = "test_graph_tffunction",
testonly = 1,
@@ -194,6 +207,7 @@ tf_cc_test(
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tfassert_eq",
+ ":test_graph_tfcond",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 67767f55dae9b15aafbd8b129328bde2c59a9ef3..9ec7df163b1425f917e9ec51559efad3e6f05e75 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -78,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir):
f.write(saver.as_saver_def().SerializeToString())
+def tfassert_eq(_):
+ x = array_ops.placeholder(dtypes.int32, name='x_hold')
+ y = array_ops.placeholder(dtypes.int32, name='y_hold')
+ control_flow_ops.Assert(
+ math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
+ math_ops.add(x, math_ops.negative(y), name='x_y_diff')
+
+
+def tfcond(_):
+ p = array_ops.placeholder(dtypes.bool, name='p_hold')
+ x = array_ops.placeholder(dtypes.int32, name='x_hold')
+ y = array_ops.placeholder(dtypes.int32, name='y_hold')
+ z = control_flow_ops.cond(p, lambda: x, lambda: y)
+ array_ops.identity(z, name='result')
+
+
def tfgather(_):
params = array_ops.placeholder(dtypes.float32, name='params')
indices = array_ops.placeholder(dtypes.int32, name='indices')
@@ -126,14 +142,6 @@ def tfsplits(_):
array_ops.identity(y, name='result')
-def tfassert_eq(_):
- x = array_ops.placeholder(dtypes.int32, name='x_hold')
- y = array_ops.placeholder(dtypes.int32, name='y_hold')
- control_flow_ops.Assert(
- math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
- math_ops.add(x, math_ops.negative(y), name='x_y_diff')
-
-
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -148,12 +156,13 @@ def main(_):
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
+ write_graph(tfassert_eq, FLAGS.out_dir)
+ write_graph(tfcond, FLAGS.out_dir)
+ write_graph(tffunction, FLAGS.out_dir)
write_graph(tfgather, FLAGS.out_dir)
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
- write_graph(tffunction, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
- write_graph(tfassert_eq, FLAGS.out_dir)
if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..94a01ad4abfaab5e4b087b7cc219e86c1d0179b8
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt
@@ -0,0 +1,20 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+ id { node_name: "p_hold" }
+ shape {}
+}
+feed {
+ id { node_name: "x_hold" }
+ shape {
+ dim { size: 1 }
+ }
+}
+feed {
+ id { node_name: "y_hold" }
+ shape {
+ dim { size: 1 }
+ }
+}
+fetch {
+ id { node_name: "result" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index aa9d968265b4619ff2e3c910e3d7455ae07bc49d..309a991fc11ab74ddd58a6345d9d40ad84fb2734 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
@@ -150,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
+TEST(TFCompileTest, Cond) {
+ CondComp cond;
+ EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
+ EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
+ EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
+ cond.arg1() = 10;
+ cond.arg2() = 20;
+ {
+ cond.arg0() = true;
+ const int32 expected_result = cond.arg1();
+ EXPECT_TRUE(cond.Run());
+ EXPECT_EQ(cond.result0(), expected_result);
+ EXPECT_EQ(cond.result0_data()[0], expected_result);
+ EXPECT_EQ(cond.result0_data(), cond.results()[0]);
+ }
+ {
+ cond.arg0() = false;
+ const int32 expected_result = cond.arg2();
+ EXPECT_TRUE(cond.Run());
+ EXPECT_EQ(cond.result0(), expected_result);
+ EXPECT_EQ(cond.result0_data()[0], expected_result);
+ EXPECT_EQ(cond.result0_data(), cond.results()[0]);
+ }
+}
+
TEST(TFCompileTest, Gather) {
GatherComp gather;
EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
@@ -525,14 +551,16 @@ TEST(TFCompileTest, HloProfiling) {
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
auto dot_profile_line = HasSubstr(
- "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
+ "%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
+ "%arg1.0.1)");
auto add_profile_line = HasSubstr(
- "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
+ "%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
+ "%arg1.0.1)");
auto tuple_profile_line = HasSubstr(
- "%tuple.2 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
- "f32[2,2]{1,0} %add)");
- auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
- auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
+ "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
+ "%dot.0.2, f32[2,2]{1,0} %add.0.5)");
+ auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
+ auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
hlo_profile_lines.erase(hlo_profile_lines.begin() + 7,
hlo_profile_lines.end());
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index af2965bba5b91a66e206f05bb8945b0dcde1d2b4..a6b3ce394c6859c4f45bbde4e39dde9229da3388 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -261,6 +261,7 @@ cc_library(
name = "create_xla_launch_op",
srcs = [
"create_xla_launch_op.cc",
+ "create_xla_launch_op.h",
],
deps = [
":common",
@@ -270,6 +271,29 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "create_xla_launch_op_test",
+ srcs = [
+ "create_xla_launch_op.h",
+ "create_xla_launch_op_test.cc",
+ ],
+ deps = [
+ ":create_xla_launch_op",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "@com_google_absl//absl/memory",
],
)
@@ -360,6 +384,31 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "xla_launch_util_test",
+ size = "small",
+ srcs = ["xla_launch_util_test.cc"],
+ deps = [
+ ":common",
+ ":xla_compilation_cache",
+ ":xla_launch_util",
+ ":xla_tensor",
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:gpu_runtime",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core/kernels:variable_ops",
+ ],
+)
+
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 18d901323f108505979be484c2bfad5998ab0748..f35e916eb937faf7e1afd53a4a5dfdb95a8bbe43 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -25,78 +27,189 @@ limitations under the License.
namespace tensorflow {
namespace {
-// Givens a NodeDef 'ndef' and the function library runtime 'flr', if
-// 'ndef' is a call to a compilable function defined in 'flr', returns OK
-// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
-// node. Otherwise, returns a non-OK.
+// Utility which searches for values in a sorted list by scanning over it once.
+// No matter how many times ScanForValue is called, the list is scanned at most
+// once. However, if a call to ScanForValue skips over a value, that value is
+// not revisited in future calls to ScanForValue, so callers must take
+// care to order their calls.
//
-// This routine is here so that FunctionLibraryRuntime can jit a
-// specific function call as requested.
-Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
- std::unique_ptr* kernel) {
- bool xla_compile = false;
- if (!flr->GetFunctionLibraryDefinition()
- ->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
- .ok() ||
- !xla_compile) {
- // Not marked as _XlaCompile=true.
- return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
+// Useful for merging multiple sorted lists in O(n) time.
+class SinglePassSearch {
+ public:
+ // Creates a SinglePassSearch object that can be used to search in `values`.
+ // Does not take ownership of `values`. `values` must outlive this.
+ // `values` must be sorted.
+ explicit SinglePassSearch(const std::vector* values)
+ : current_index_(0), values_(values) {}
+
+ // Scans forward in the vector looking for "value", updating the internal
+ // position in to the vector.
+ // Returns true iff the vector contains the given value at or after current
+ // position.
+ // Not thread-safe.
+ bool ScanForValue(int value) {
+ while (current_index_ < values_->size() &&
+ (*values_)[current_index_] <= value) {
+ if ((*values_)[current_index_] == value) {
+ current_index_++;
+ return true;
+ }
+ current_index_++;
+ }
+ return false;
}
- // Make sure that kernels have been registered on the JIT device.
- XlaOpRegistry::RegisterCompilationKernels();
- if (!IsCompilable(flr, ndef)) {
- // ndef is calling a function that XLA can't compile.
- return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
+
+ private:
+ int current_index_;
+ const std::vector* values_;
+};
+
+Status CompilationRequested(const FunctionLibraryRuntime& flr,
+ const NodeDef& node_def) {
+ bool xla_compile = false;
+ // Check if op is marked _XlaCompile=true.
+ Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
+ node_def, kXlaCompileAttr, &xla_compile);
+ if (!status.ok() || !xla_compile) {
+ if (VLOG_IS_ON(3)) {
+ if (!status.ok()) {
+ VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
+ << node_def.op() << ". status=" << status.ToString();
+ } else {
+ VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
+ }
+ }
+ return Status(error::INVALID_ARGUMENT, "");
}
+ return Status::OK();
+}
+
+// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
+// runtime, returns this function's body in `fbody` as well as the indices
+// of its constant and resource arguments.
+// `fbody` is owned by `flr`.
+// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
+// They are sorted in ascending order on this function's return.
+Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
+ const NodeDef& node_def,
+ const FunctionBody** fbody,
+ std::vector* constant_arg_indices,
+ std::vector* resource_arg_indices) {
FunctionLibraryRuntime::Handle handle;
- // If ndef is not instantiable, e.g., the function does not exist,
+ // If node_def is not instantiable, e.g., the function does not exist,
// simply bail out.
TF_RETURN_IF_ERROR(
- flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
- const FunctionBody* fbody = flr->GetFunctionBody(handle);
- CHECK(fbody); // Can't be nullptr since we just instantiated it.
- std::vector const_args(fbody->arg_types.size());
+ flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
+ *fbody = flr->GetFunctionBody(handle);
+ CHECK(*fbody); // Can't be nullptr since we just instantiated it.
+ const DataTypeVector& arg_types = (*fbody)->arg_types;
+ std::vector const_args(arg_types.size());
// If we can't analyze the const args. Bail out.
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
- // There is a const arg. Bail out.
- return errors::InvalidArgument("Const arg: ", i, " in ",
- DebugString(fbody->fdef));
+ constant_arg_indices->push_back(i);
+ }
+ }
+
+ // There can be hundreds of resource variables. Reserve the space for them.
+ // We don't reserve for constants above as they are usually few.
+ resource_arg_indices->reserve(arg_types.size());
+ for (int i = 0; i < arg_types.size(); ++i) {
+ if (arg_types[i] == DT_RESOURCE) {
+ resource_arg_indices->push_back(i);
}
}
- NodeDef launch_def;
- launch_def.set_name(ndef.name());
- launch_def.set_op("_XlaLaunch");
- launch_def.set_device(flr->device()->name());
- AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
- AddNodeAttr("Nresources", 0, &launch_def);
- AddNodeAttr("Targs", fbody->arg_types, &launch_def);
- AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
- NameAttrList func;
- func.set_name(ndef.op());
- *(func.mutable_attr()) = ndef.attr();
- AddNodeAttr("function", func, &launch_def);
-
- // TODO(b/32387911): Handles the host memory types across function
- // calls properly. For now, we assume all inputs and outputs are on
- // the device memory.
+ return Status::OK();
+}
+
+} // namespace
+
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+ std::unique_ptr* kernel) {
+ TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));
+
+ VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();
+
+ // Make sure that kernels have been registered on the JIT device.
+ XlaOpRegistry::RegisterCompilationKernels();
+ if (!IsCompilable(flr, node_def)) {
+ // node_def is calling a function that XLA can't compile.
+ return errors::InvalidArgument("Not compilable: ",
+ node_def.ShortDebugString());
+ }
+
+ // Get function body, constant args, and resource args.
+ const FunctionBody* fbody = nullptr;
+ std::vector constant_arg_indices;
+ std::vector resource_arg_indices;
+ TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
+ flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
+
+ // Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
+ // These indices are used only for optimization purposes. They allow us
+ // to loop over constant_arg_indices and resource_arg_indices only once
+ // while iterating over all the function arguments checking if it is a
+ // resource or a constant.
+ // The reason we optimized this code is because functions can have a lot of
+ // captured arguments. For example, the backward pass of ResNet50 takes in all
+ // 214 variables and a similar number of activations.
+ SinglePassSearch constants_search(&constant_arg_indices);
+ SinglePassSearch resources_search(&resource_arg_indices);
+ for (int i = 0; i < fbody->arg_types.size(); ++i) {
+ if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
+ // Compile-time constants and resource handles are expected to be in
+ // host memory.
+ input_memory_types[i] = HOST_MEMORY;
+ }
+ }
+ // One might wonder, about the case where a compile-time constant argument
+ // (which must be in host memory) is also used as an input into an op,
+ // e.g. Add, that expects its inputs in device memory. Here is how it
+ // works now.
+ // First, what do we mean by "op expects an input in XYZ memory"?
+ // There are two types of "ops" here: the tf2xla kernel and the HLO
+ // computation it builds. The tf2xla kernel needs to retrieve the actual
+ // numeric value of the compile-time constant tensors, so it really expects
+ // them to be on in host memory. However, for other inputs, it refers to them
+ // using xla::ComputationDataHandle, which is just a symbolic handle that
+ // xla::ComputationBuilder assigns. How does this handle gets assigned for
+ // constant arguments? Even constant arguments get an _Arg node in the graph
+ // instatiated for Function compilation. The tf2xla kernel for constant _Arg
+ // nodes takes the constant value, converts it to XlaLiteral, and feeds it
+ // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
+ // constant XlaLiteral is included in the HLO graph, and subsequently, in
+ // the actual executable, which is copied to the device before being
+ // executed. Thus, when this executable runs, the constant is available in
+ // device memory.
+
+ // XlaLaunch kernel keeps all outputs (including constants, which it copies),
+ // in device memory
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
+ // Create the kernel.
+ NameAttrList function;
+ function.set_name(node_def.op());
+ *(function.mutable_attr()) = node_def.attr();
+
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
- dev->GetAllocator(AllocatorAttributes()), &launch_def,
+ dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
- kernel->reset(new XlaLocalLaunchOp(&construction));
+
+ *kernel = absl::make_unique(
+ &construction, constant_arg_indices, resource_arg_indices, function);
return s;
}
+namespace {
+
bool RegisterLaunchOpCreator() {
RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
return true;
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..98a22e351532c197c69c5ea908305d885fd2c9d0
--- /dev/null
+++ b/tensorflow/compiler/jit/create_xla_launch_op.h
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class FunctionLibraryRuntime;
+class OpKernel;
+
+// Given a NodeDef 'node_def' and the function library runtime 'flr', if
+// 'node_def' is a call to a compilable function defined in 'flr', returns OK
+// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
+// node. Otherwise, returns a non-OK.
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+ std::unique_ptr* kernel);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bcd5e75c7e4c021a9be874ed96e994768bb80811
--- /dev/null
+++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+NodeDef ToNodeDef(const string& text) {
+ NodeDef node_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
+ return node_def;
+}
+
+// Create a FunctionDef that takes one resource and one regular param
+FunctionDef XTimesY() {
+ return FunctionDefHelper::Define(
+ // Name
+ "XTimesY",
+ // Args
+ {"x: float", "y: resource"},
+ // Return values
+ {"z: float"},
+ // Attr def
+ {},
+ // Nodes
+ {
+ {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}},
+ {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}},
+ });
+}
+
+class CreateXlaLaunchOpTest : public ::testing::Test {
+ protected:
+ void Init(const std::vector& flib) {
+ SessionOptions options;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", 1});
+ TF_CHECK_OK(DeviceFactory::AddDevices(
+ options, "/job:localhost/replica:0/task:0", &devices_));
+
+ FunctionDefLibrary proto;
+ for (const auto& fdef : flib) {
+ *(proto.add_function()) = fdef;
+ }
+ lib_def_ = absl::make_unique(
+ OpRegistry::Global(), proto);
+ OptimizerOptions opts;
+ device_mgr_ = absl::make_unique(devices_);
+ pflr_ = absl::make_unique(
+ device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
+ opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
+ flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
+ }
+
+ FunctionLibraryRuntime* flr_;
+ std::vector devices_;
+ std::unique_ptr device_mgr_;
+ std::unique_ptr lib_def_;
+ std::unique_ptr pflr_;
+
+ std::unique_ptr kernel_;
+};
+
+AttrValue BoolAttr(bool b) {
+ AttrValue v;
+ v.set_b(b);
+ return v;
+}
+
+TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) {
+ FunctionDef fdef = XTimesY();
+ (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(
+ flr_, ToNodeDef(R"pb(
+ name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
+ )pb"), &kernel_);
+ ASSERT_TRUE(status.ok()) << status.ToString();
+
+ EXPECT_EQ("XTimesY", kernel_->name());
+ EXPECT_EQ("XTimesY", kernel_->type_string());
+
+ EXPECT_EQ(2, kernel_->num_inputs());
+ EXPECT_EQ(DT_FLOAT, kernel_->input_type(0));
+ EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1));
+ EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]);
+ EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]);
+
+ EXPECT_EQ(1, kernel_->num_outputs());
+ EXPECT_EQ(DT_FLOAT, kernel_->output_type(0));
+ EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]);
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) {
+ FunctionDef fdef = XTimesY();
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+ name: 'XTimesY'
+ op: 'XTimesY'
+ input: 'a'
+ input: 'b'
+ )proto"), &kernel_);
+ EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) {
+ FunctionDef fdef = XTimesY();
+ (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+ name: 'XTimesY'
+ op: 'XTimesY'
+ input: 'a'
+ input: 'b'
+ )proto"), &kernel_);
+ EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index f06debaf316c0172a5683e56aa5de6ebb83fbece..6d1e3325ebd35b9608ea273fb7de39bad381e60d 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -240,7 +240,7 @@ class Encapsulator {
// Once edges between compiled and outside_compilation clusters have been
// replaced by send/recv ops, some dependencies may no longer be apparent.
// A clustering pass finds all the dependencies between HC nodes that are only
- // present as a result of edges between nodes in outside_compilaton clusters.
+ // present as a result of edges between nodes in outside_compilation clusters.
// Suppose there is a path from outside_compilation cluster C in subgraph S
// to outside_compilation cluster D in subgraph T. If S != T then a control
// edge is added from the call node for S to the call node for T, which
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 049d170fa48928474b894f2d0e1f2243c5f87275..86a9fd3b8e124e581bc4b73f264dbd5be46c790a 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -39,15 +39,15 @@ limitations under the License.
namespace tensorflow {
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), device_type_(ctx->device_type()) {
- const NameAttrList* func;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
- function_ = *func;
- DataTypeVector constant_types;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
- num_constant_args_ = constant_types.size();
- OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector& constants,
+ const std::vector& resources,
+ const NameAttrList& function)
+ : OpKernel(ctx),
+ constants_(constants),
+ resources_(resources),
+ device_type_(ctx->device_type()),
+ function_(function) {
if (device_type_ == DeviceType(DEVICE_CPU)) {
platform_id_ = se::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
@@ -57,8 +57,8 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
}
}
-Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache) {
+Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** cache) {
const XlaDevice::Metadata* metadata;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
if (s.ok()) {
@@ -90,8 +90,8 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
return Status::OK();
}
-void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
- VLOG(1) << "XlaLocalLaunchOp::Compute "
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XlaLocalLaunchOpBase::Compute "
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
@@ -124,7 +124,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
}
std::map variables =
- SnapshotResourceVariables(ctx, num_resource_args_);
+ SnapshotResourceVariables(ctx, resources_);
xla::LocalClient* client = static_cast(cache->client());
@@ -161,7 +161,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
xla::LocalExecutable* executable;
std::map constant_args;
- for (int i = 0; i < num_constant_args_; ++i) {
+ for (int i : constants_) {
constant_args.insert({i, ctx->input(i)});
}
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
@@ -170,8 +170,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
- XlaComputationLaunchContext launch_context(
- num_resource_args_, client, xla_allocator, allocate_xla_tensors);
+ XlaComputationLaunchContext launch_context(client, xla_allocator,
+ allocate_xla_tensors);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
@@ -194,6 +194,62 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Done";
}
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
+// in error case, it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
+ do { \
+ ::tensorflow::Status _s(__VA_ARGS__); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ return RET; \
+ } \
+ } while (0)
+
+// Helper static functions to construct parameters for
+// XlaLocalLaunchBase constructor from OpKernelConstruction.
+std::vector ConstantsVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector(),
+ ctx->GetAttr("Tconstants", &constant_types));
+ std::vector constants(constant_types.size());
+ std::iota(constants.begin(), constants.end(), 0);
+ return constants;
+}
+
+std::vector ResourcesVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector(),
+ ctx->GetAttr("Tconstants", &constant_types));
+
+ DataTypeVector arg_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector(),
+ ctx->GetAttr("Targs", &arg_types));
+
+ int num_resources;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector(),
+ ctx->GetAttr("Nresources", &num_resources));
+
+ std::vector resources(num_resources);
+ std::iota(resources.begin(), resources.end(),
+ constant_types.size() + arg_types.size());
+ return resources;
+}
+
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+ const NameAttrList* func;
+ OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+ return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+} // namespace
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+ : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
+ FunctionAttr(ctx)) {}
+
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
VLOG(1) << "XlaLocalLaunchOp destroyed";
}
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
index 8f8e646f0ff6d94dfdf56721cacfce7fa658beb6..8dfc4b382d51151b6383fe7dd75429f3124d39be 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h
@@ -26,6 +26,41 @@ limitations under the License.
namespace tensorflow {
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource arguments explicitly.
+// It does not have corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+ XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector& constants,
+ const std::vector& resources,
+ const NameAttrList& function);
+ XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+ XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+ ~XlaLocalLaunchBase() override = default;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ protected:
+ // Builds a XlaCompilationCache class suitable for the current device.
+ Status BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** cache);
+
+ // Indexes of compile-time constant inputs
+ std::vector constants_;
+ // Indexes of resource inputs
+ std::vector resources_;
+
+ DeviceType device_type_;
+ NameAttrList function_;
+ se::Platform::Id platform_id_;
+};
+
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
@@ -35,26 +70,12 @@ namespace tensorflow {
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
// memory.
-class XlaLocalLaunchOp : public OpKernel {
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
public:
explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
~XlaLocalLaunchOp() override;
- void Compute(OpKernelContext* ctx) override;
-
private:
- // Builds a XlaCompilationCache class suitable for the current device.
- Status BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** compiler);
-
- DeviceType device_type_;
- NameAttrList function_;
- int num_constant_args_;
- // Number of resource variable arguments.
- int num_resource_args_;
-
- se::Platform::Id platform_id_;
-
TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 60458f6f3314b2c3b65be1c90e051b2a670383bc..6b83cf67ffc571f235ae84d0de58254c5d7e4962 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -48,13 +48,12 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable) {
std::map variables = GetVariables(ctx);
- int64 num_resource_args = variables.size();
xla::LocalClient* client = metadata.client();
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
- num_resource_args, client, client->backend().memory_allocator(), true);
+ client, client->backend().memory_allocator(), true);
launch_context.PopulateInputs(ctx, result, variables);
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 2a7f04271d4b7ea330f32b88ea1e3f4037988a91..0223f97a032cf9efe56005248ce65d412e340b78 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -38,14 +38,13 @@ using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;
} // anonymous namespace
-std::map SnapshotResourceVariables(OpKernelContext* ctx,
- int num_variables) {
+std::map SnapshotResourceVariables(
+ OpKernelContext* ctx, const std::vector& variables) {
std::map snapshot;
- int first_variable = ctx->num_inputs() - num_variables;
- for (int i = 0; i < num_variables; ++i) {
+ for (int i : variables) {
Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
- OptionalTensor& tensor = snapshot[first_variable + i];
+ ResourceHandle handle = HandleFromInput(ctx, i);
+ OptionalTensor& tensor = snapshot[i];
if (LookupResource(ctx, handle, &variable).ok()) {
tf_shared_lock lock(*variable->mu());
tensor.name = handle.name();
@@ -77,16 +76,16 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) {
return Status::OK();
}
-namespace {
+namespace internal {
// Return the 'index''th subtree of the given ShapedBuffer as a
// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the
// subtree, and sets the input's buffer pointers to nullptr for the subtree.
ScopedShapedBuffer ExtractSubShapedBuffer(
ShapedBuffer* shaped_buffer, int index,
xla::DeviceMemoryAllocator* allocator) {
- xla::Shape on_host_shape = xla::ShapeUtil::GetTupleElementShape(
+ const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape(
shaped_buffer->on_host_shape(), index);
- xla::Shape on_device_shape = xla::ShapeUtil::GetTupleElementShape(
+ const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape(
shaped_buffer->on_device_shape(), index);
ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape,
@@ -98,20 +97,23 @@ ScopedShapedBuffer ExtractSubShapedBuffer(
sub_shape_tree.CopySubtreeFrom(shape_tree,
/*source_base_index=*/{index},
/*target_base_index=*/{});
- for (auto& index_to_buffer : shape_tree) {
- if (!index_to_buffer.first.empty() && index_to_buffer.first[0] == index) {
- index_to_buffer.second = se::DeviceMemoryBase(nullptr, 0);
- }
- }
+ shape_tree.ForEachMutableElement(
+ [index](const xla::ShapeIndex& shape_index,
+ tensorflow::se::DeviceMemoryBase* data) {
+ // shape_index is empty for the root node. Ignore that.
+ if (!shape_index.empty() && shape_index[0] == index) {
+ *data = tensorflow::se::DeviceMemoryBase(nullptr, 0);
+ }
+ });
return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator);
}
-} // namespace
+} // namespace internal
+using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext(
- int64 num_resource_args, xla::LocalClient* client,
- xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors)
- : num_resource_args_(num_resource_args),
- client_(client),
+ xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
+ bool allocate_xla_tensors)
+ : client_(client),
xla_allocator_(xla_allocator),
allocate_xla_tensors_(allocate_xla_tensors) {}
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 8a6ff3b0c751206d184da63ef1a36e750a1252a5..a2431253f8c44bdd9b99a253f79bdb14722d7c72 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -31,15 +31,17 @@ limitations under the License.
namespace tensorflow {
class XlaAllocator;
-// Takes a snapshot of the values of resource variable arguments, which are
-// the last `num_variables` arguments. We snapshot tensors that back
+// Takes a snapshot of the values of resource variable arguments, whose
+// indices are specified in `variables` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
-// Returns a map of TensorFlow argument index to resource variable.
-std::map SnapshotResourceVariables(OpKernelContext* ctx,
- int num_variables);
+// Returns a map of TensorFlow argument index to resource variable. If a
+// resource variable is not initialized, the corresponding OptionalTensor
+// will have its `present` field set to false.
+std::map SnapshotResourceVariables(
+ OpKernelContext* ctx, const std::vector& variables);
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
@@ -72,7 +74,7 @@ class XlaComputationLaunchContext {
// Create a new launch context. 'allocate_xla_tensors' is true if allocated
// output tensors and variables are always XlaTensors. If false they are
// assumed to be "normal" device pointers.
- XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
+ XlaComputationLaunchContext(xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors);
@@ -92,7 +94,6 @@ class XlaComputationLaunchContext {
const std::vector& arguments() const { return arg_ptrs_; }
private:
- int64 num_resource_args_;
xla::LocalClient* client_;
xla::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
@@ -140,6 +141,17 @@ class XlaTensorBuffer : public TensorBuffer {
Allocator* allocator_;
};
+// Exposed in this header file for microbenchmarking purposes, but this is an
+// internal implementation detail.
+namespace internal {
+// Return the 'index''th subtree of the given ShapedBuffer as a
+// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the
+// subtree, and sets the input's buffer pointers to nullptr for the subtree.
+xla::ScopedShapedBuffer ExtractSubShapedBuffer(
+ xla::ShapedBuffer* shaped_buffer, int index,
+ xla::DeviceMemoryAllocator* allocator);
+} // namespace internal
+
} // namespace tensorflow
#endif
diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..27813efc0bc0aecdbea2dfce5ca27ba704ea45e2
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_launch_util_test.cc
@@ -0,0 +1,64 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Contains microbenchmarks for performance critical functions in
+// xla_launch_util.cc.
+
+#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs
+// (cardinality of each non-leaf node's children).
+void BM_ExtractSubBuffer(int iters, int depth, int fan_out) {
+ tensorflow::testing::StopTiming();
+ xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128});
+ for (int i = 0; i < depth; ++i) {
+ std::vector shapes(fan_out, shape);
+ shape = xla::ShapeUtil::MakeTupleShape(shapes);
+ }
+ xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr,
+ /*device_ordinal=*/0);
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ // Extract a buffer from approximately the middle of the first level of the
+ // tree.
+ tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer,
+ /*index=*/fan_out / 2,
+ /*allocator=*/nullptr)
+ .release();
+ }
+}
+
+BENCHMARK(BM_ExtractSubBuffer)
+ ->ArgPair(1, 4)
+ ->ArgPair(1, 8)
+ ->ArgPair(1, 32)
+ ->ArgPair(1, 64)
+ ->ArgPair(1, 128)
+ ->ArgPair(1, 256)
+ ->ArgPair(1, 512)
+ ->ArgPair(2, 4)
+ ->ArgPair(2, 8)
+ ->ArgPair(2, 32)
+ ->ArgPair(2, 64)
+ ->ArgPair(2, 128);
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ tensorflow::testing::RunBenchmarks();
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index 922a91897312096e4bb6ee2a1cc153e0039e2c7a..6b29c82ec11e39ad525663991e179443c2b6dca7 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -54,7 +54,7 @@ class XlaTensor {
// Some Tensors can have complex on-device shapes, including tuple shapes. To
// manage the memory for these tensors a ShapedBuffer may be required.
- // Return true if this TensorInfo contains a ShapedBuffer.
+ // Return true if this XlaTensor contains a ShapedBuffer.
bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
// Return the contained ShapedBuffer.
// REQUIRES: has_shaped_buffer()
@@ -62,7 +62,7 @@ class XlaTensor {
CHECK(has_shaped_buffer());
return *shaped_buffer_;
}
- // Mutates the TensorInfo to set the ShapedBuffer.
+ // Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ =
xla::MakeUnique(std::move(shaped_buffer));
@@ -72,7 +72,7 @@ class XlaTensor {
// in on-demand mode to avoid re-copying values from the device if we know the
// host value already.
- // Return true if this TensorInfo contains a host tensor.
+ // Return true if this XlaTensor contains a host tensor.
bool has_host_tensor() const { return host_tensor_ != nullptr; }
// Return the contained host tensor.
// REQUIRES: has_host_tensor()
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index a94b298f87832057c6ec86a1ea250a54ed1b4ee0..9791792f29ca05f4ece77cca6305ed05343d1d38 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -300,6 +300,10 @@ tf_xla_py_test(
name = "extract_image_patches_op_test",
size = "small",
srcs = ["extract_image_patches_op_test.py"],
+ tags = [
+ "manual",
+ "notap",
+ ],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@@ -323,7 +327,11 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
"//tensorflow/python:platform_test",
+ "//tensorflow/python/eager:function",
],
)
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index bdd0185dfe4abe9d9acecc5381ff82c54b8c0705..5ab1585f8c6e07d6e3f0f40c99840b176492e523 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -24,10 +24,16 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.layers import convolutional
+from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
@@ -43,7 +49,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen0(self):
with self.test_scope():
- empty = constant_op.constant([], dtype=dtypes.int32)
+ empty = constant_op.constant([], dtype=dtypes.float32)
result = array_ops.unstack(empty, 0)
self.assertTrue(isinstance(result, list))
self.assertEqual(0, len(result))
@@ -51,7 +57,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen1(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+ value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
result = array_ops.split(value, 1, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(1, len(result))
@@ -60,7 +66,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen3(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+ value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
result = array_ops.split(value, 3, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(3, len(result))
@@ -131,7 +137,105 @@ class EagerTest(XLATestCase):
self.assertEqual(2., grads[0][0].numpy())
-if __name__ == "__main__":
+class EagerFunctionTest(XLATestCase):
+
+ def testBasic(self):
+ with self.test_scope():
+ matmul = function.defun(math_ops.matmul, compiled=True)
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ sq = matmul(t, t, transpose_a=True)
+ self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
+
+ def testConv(self):
+ if 'GPU' in self.device:
+ # TODO(b/32333178)
+ self.skipTest('Current implementation of RandomStandardNormal kernel '
+ 'is very slow on GPU, and has been blacklisted.')
+ with self.test_scope():
+ data_format = 'channels_last'
+ conv = convolutional.Conv2D(
+ filters=1, kernel_size=2, padding='VALID',
+ data_format=data_format, activation=nn_ops.relu,
+ kernel_initializer=init_ops.ones_initializer(),
+ bias_initializer=init_ops.zeros_initializer())
+ pool = pooling.MaxPooling2D(2, 2, data_format=data_format)
+
+ def model(x):
+ x = conv(x)
+ return pool(x)
+ model = function.defun(model, compiled=True)
+
+ x = array_ops.ones([1, 4, 4, 1])
+ y = model(x)
+ self.assertAllEqual(y.numpy(), [[[[4.]]]])
+
+ def testReadVariable(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun(compiled=True)
+ def f():
+ return v.read_value()
+
+ var = f()
+ self.assertEqual(1.0, var.numpy())
+
+ def testUpdateVariable(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ def f(v):
+ v.assign_add(1.0)
+ return v
+
+ f = function.defun(f, compiled=True)
+
+ var = f(v)
+ self.assertEqual(2.0, var.numpy())
+
+ def testAllArgumentKinds(self):
+ """Test a complex function that takes different argument kinds.
+
+ tf2xla machinery that translates, compiles, and runs defuns
+ classifies arguments into: compile-time constants, regular tensors,
+ and resources. This test creates a function with a mix of all these
+ kinds. Moreover, the order of function arguments is intentionally mixed up.
+
+ This also tests the case when the same argument is a compile-time constant
+ as well as used in an operation that normally expects its inputs to be
+ in device memory - addition in this case.
+ """
+ with self.test_scope():
+ def foo(c1, r1, v1, c2, v2, r2):
+ # c1 and c2 are compile-time constants
+ # r1 and r2 are regular tensors
+ # v1 and v2 are resource variables
+ a = c1 + r1
+ b = math_ops.cast(c2, dtypes.float32) + v2
+ c = array_ops.slice(v1, c1, c2)
+ d = r2 * v2
+ return a, b, c, d
+
+ foo = function.defun(foo, compiled=True)
+
+ c1 = [0, 0]
+ c2 = array_ops.ones([2], dtype=dtypes.int32)
+
+ r1 = array_ops.ones([2])
+ r2 = [[2., 2.], [3., 3.]]
+
+ v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
+ v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])
+
+ a, b, c, d = foo(c1, r1, v1, c2, v2, r2)
+
+ self.assertAllEqual([1, 1], a.numpy())
+ self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
+ self.assertAllEqual([[1.]], c.numpy())
+ self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
+
+
+if __name__ == '__main__':
ops.enable_eager_execution(
config=config_pb2.ConfigProto(log_device_placement=True))
googletest.main()
diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py
index 1434e965e3d7eaeca94ad0fa97498f884e30e115..d68d32057a367776d5b70d5ac21d5618297c605d 100644
--- a/tensorflow/compiler/tests/oom_test.py
+++ b/tensorflow/compiler/tests/oom_test.py
@@ -22,6 +22,8 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
@@ -42,20 +44,33 @@ class OutOfMemoryTest(xla_test.XLATestCase):
"""
def test_loop():
- size = 2e8
+ size = int(2e8)
while True:
with self.test_session():
- # Force the compiled code to not be constant by feeding in an addend.
- p = array_ops.placeholder(dtypes.float32, shape=[])
+ # Force the compiled code to not be constant by feeding in a
+ # parameter.
+ p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1])
with self.test_scope():
- # Create a large R1 tensor.
- c = array_ops.zeros([size, 1]) + p
+ # Create a computation that produces a large R1 tensor as an
+ # intermediate result. Reduce it down so that if this file was
+ # compiled without --config=cuda, we don't force a D2H copy of a
+ # large tensor and potentially OOM the host.
+ #
+ # This is a bit tricky because XLA:GPU doesn't currently support RNG
+ # ops. Here we rely on the fact that XLA doesn't do algebraic
+ # simplifications on conv(, ).
+ c = math_ops.reduce_sum(
+ nn_ops.convolution(
+ array_ops.ones([1, size, 1]),
+ p,
+ padding='SAME',
+ data_format='NWC'))
- c.eval(feed_dict={p: 1.0})
+ c.eval(feed_dict={p: [[[1.0]], [[2.0]]]})
size *= 2
self.assertRaises(errors.ResourceExhaustedError, test_loop)
-if __name__ == "__main__":
+if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index 2c084b04fa2f67ad0d86508109522d7bead206eb..7420724bdbeab63b39542ada59328621febad895 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import functools
+import itertools
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
@@ -155,5 +156,68 @@ class ReduceOpsTest(XLATestCase):
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
+class ReduceOpPrecisionTest(XLATestCase):
+
+ def _testReduceSum(self,
+ expected_result,
+ dtype,
+ test_inputs,
+ rtol=1e-3,
+ atol=1e-4):
+ """Tests reduce sum on a list of input arrays.
+
+ For each array in test_inputs, check that performing reduce sum on the array
+ produces a value that is close to the expected result.
+
+ Args:
+ expected_result: the expected result.
+ dtype: the data type of the reduce sum operation.
+ test_inputs: a list of input arrays for the reduce sum operation.
+ rtol: the relative error.
+ atol: the absolute error.
+ """
+
+ for test_input in test_inputs:
+ with self.test_session() as sess:
+ with self.test_scope():
+ a = array_ops.placeholder(dtype)
+ index = array_ops.placeholder(dtypes.int32)
+ out = math_ops.reduce_sum(a, index)
+ result = sess.run(out, {
+ a: np.array(test_input, dtype=dtype),
+ index: [0]
+ })
+ # Compare the results using float32 type.
+ self.assertAllClose(
+ np.float32(result),
+ np.float32(expected_result),
+ rtol=rtol,
+ atol=atol)
+
+ def testReduceSumF16(self):
+ """Tests the reduce sum of float16 doesn't lose too much precision."""
+
+ if np.float16 not in self.all_types:
+ return
+
+ f16_max = np.finfo(np.float16).max
+ self._testReduceSum(
+ f16_max, np.float16,
+ itertools.permutations([f16_max, f16_max, f16_max * (-1.0)], 3))
+
+ def testReduceSumBF16(self):
+ """Tests the reduce sum of bfloat16 doesn't lose too much precision."""
+
+ if dtypes.bfloat16.as_numpy_dtype not in self.all_types:
+ return
+
+ bf16_max = np.float32(dtypes.bfloat16.max)
+ f32_max = dtypes.float32.max
+ value = min(bf16_max, f32_max - bf16_max)
+ self._testReduceSum(
+ dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype,
+ itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3))
+
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 4336ebdbd184a081619f0a6951dd4514735c6eb6..b6f8390a45d43bf7666b90e14cc6ff2f3f61947e 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -86,6 +86,15 @@ class StatelessRandomOpsTest(XLATestCase):
# seed were not fixed.
self.assertTrue(self._chi_squared(y, 10) < 16.92)
+ def testRandomNormalIsFinite(self):
+ with self.test_session() as sess, self.test_scope():
+ for dtype in self._random_types():
+ seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+ x = stateless.stateless_random_uniform(
+ shape=[10000], seed=seed_t, dtype=dtype)
+ y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
+ self.assertTrue(np.all(np.isfinite(y)))
+
def _normal_cdf(self, x):
"""Cumulative distribution function for a standard normal distribution."""
return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 942504e6bd4c9ce93c9482251823efcbb46ab1c8..4fca51f54d320e843343f80d7df1177f80f1d99f 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -81,7 +81,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client",
- "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -168,9 +168,9 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -215,7 +215,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index b20c1ffc7d8956f3f5530ee63e9b711a26439be5..8115a26210a8e9e95e851f350e34dcdfa2519a64 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -51,6 +51,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
const std::vector& expressions,
std::vector* args) {
auto builder = ctx->builder();
+ auto client = ctx->compiler()->client();
std::vector compile_time_constant_flags(expressions.size());
TF_RETURN_IF_ERROR(
@@ -72,8 +73,10 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
arg.kind = XlaCompiler::Argument::kConstant;
TF_RET_CHECK(expressions[i]->resource() == nullptr)
<< "Input with resource is not yet implemented.";
+ TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph(
+ expressions[i]->handle()));
TF_ASSIGN_OR_RETURN(auto literal,
- builder->ComputeConstant(expressions[i]->handle()));
+ client->ComputeConstant(constant_graph));
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
} else {
@@ -212,7 +215,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
TF_RET_CHECK(arguments.size() == expressions.size());
- std::vector handles;
+ std::vector handles;
for (int64 i = 0; i < expressions.size(); ++i) {
if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
continue;
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 00fd08b1a0750739445a124adc7ccf436a4a9b71..85ab4c41bf6a754236066260819f103970e603ae 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -114,8 +114,8 @@ tf_kernel_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",
"//tensorflow/core:lib",
@@ -151,7 +151,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -167,7 +167,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -203,8 +203,8 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:argmax_op",
diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
index 5c9f66df101bfb731d6114c23933e241af5dcbeb..1e59868621475cf72f4cc8b14dafec2dd8cd5c95 100644
--- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
@@ -29,7 +29,7 @@ class AddNOp : public XlaOpKernel {
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("AddN requires at least one argument"));
- xla::ComputationDataHandle sum = ctx->Input(0);
+ xla::XlaOp sum = ctx->Input(0);
for (int i = 1; i < ctx->num_inputs(); ++i) {
sum = ctx->builder()->Add(sum, ctx->Input(i));
}
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index 931175be1111ed5f70afbdf351ee53c59c1367de..15e1815a4cf07ff50dd1431b6790d14781da590f 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -48,9 +48,9 @@ class FusedBatchNormOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(ctx->input_type(1), &scale_type));
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
TensorShape input_shape = ctx->InputShape(0);
int feature_index =
@@ -62,7 +62,7 @@ class FusedBatchNormOp : public XlaOpKernel {
input = builder->ConvertElementType(input, scale_type);
if (is_training_) {
- xla::ComputationDataHandle output = builder->BatchNormTraining(
+ xla::XlaOp output = builder->BatchNormTraining(
input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);
// In training mode, outputs the normalized value as well as the
@@ -79,7 +79,7 @@ class FusedBatchNormOp : public XlaOpKernel {
ctx->SetOutput(3, builder->GetTupleElement(output, 1));
ctx->SetOutput(4, builder->GetTupleElement(output, 2));
} else {
- xla::ComputationDataHandle output = builder->BatchNormInference(
+ xla::XlaOp output = builder->BatchNormInference(
input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
epsilon_, feature_index);
ctx->SetOutput(0, builder->ConvertElementType(output, input_type));
@@ -118,7 +118,7 @@ class FusedBatchNormGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
DataType input_dtype = ctx->input_type(0);
DataType scale_dtype = ctx->input_type(2);
@@ -137,11 +137,11 @@ class FusedBatchNormGradOp : public XlaOpKernel {
const int feature_index =
GetTensorFeatureDimIndex(input_dims, data_format_);
- xla::ComputationDataHandle x_backprop;
- xla::ComputationDataHandle scale_backprop;
- xla::ComputationDataHandle offset_backprop;
+ xla::XlaOp x_backprop;
+ xla::XlaOp scale_backprop;
+ xla::XlaOp offset_backprop;
if (is_training_) {
- xla::ComputationDataHandle output =
+ xla::XlaOp output =
b->BatchNormGrad(activations, scale, mean, var, grad_backprop,
epsilon_, feature_index);
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
index 569950c2dfaeb61028049a263a962dfa54a62e09..642278ab994bf3cc84396f093ed56b009a1435c1 100644
--- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -20,9 +20,8 @@ limitations under the License.
namespace tensorflow {
namespace {
-void BatchToSpace(XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input, DataType input_dtype,
- const TensorShape& input_tensor_shape,
+void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
+ DataType input_dtype, const TensorShape& input_tensor_shape,
gtl::ArraySlice block_shape,
const xla::Literal& crops) {
const int input_rank = input_tensor_shape.dims();
@@ -46,7 +45,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
", 2] instead of ",
xla::ShapeUtil::HumanString(crops.shape())));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const int64 batch_size = input_shape[0];
// Compute the product of the block_shape values.
@@ -73,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
reshaped_shape[block_rank] = batch_size / block_num_elems;
std::copy(input_shape.begin() + 1, input_shape.end(),
reshaped_shape.begin() + block_rank + 1);
- xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce `permuted` of shape
// [batch / prod(block_shape),
@@ -91,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
}
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
- xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation);
+ xla::XlaOp permuted = b->Transpose(reshaped, permutation);
// 3. Reshape `permuted` to produce `reshaped_permuted` of shape
// [batch / prod(block_shape),
@@ -111,8 +110,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_permuted_shape.begin() + 1 + block_rank);
- xla::ComputationDataHandle reshaped_permuted =
- b->Reshape(permuted, reshaped_permuted_shape);
+ xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape);
// 4. Crop the start and end of dimensions `[1, ..., M]` of
// `reshaped_permuted` according to `crops` to produce the output of shape:
@@ -139,7 +137,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
"Cropped size must be non-negative: start: ", crop_start,
" end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
}
- xla::ComputationDataHandle output =
+ xla::XlaOp output =
b->Slice(reshaped_permuted, start_indices, end_indices, strides);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
index ed33b8ed2e823f313a9a7fe220390bc617288405..9d677f426650ea17a49e5ab1401078f04623fe97 100644
--- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
@@ -60,7 +60,7 @@ class BiasOp : public XlaOpKernel {
"of the input tensor: ",
bias_shape.DebugString(), " vs. ", input_shape.DebugString()));
- xla::ComputationDataHandle result =
+ xla::XlaOp result =
ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim});
ctx->SetOutput(0, result);
}
@@ -103,7 +103,7 @@ class BiasAddGradOp : public XlaOpKernel {
std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0);
std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(),
feature_dim + 1);
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
const DataType accumulation_type =
XlaHelpers::SumAccumulationType(input_type(0));
auto converted =
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 2436a6074a11ad66387b232dd1c5aa135875bfc3..f04cde878e98002d9442e0f3ec251c5197ef7969 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@@ -34,14 +34,13 @@ namespace {
class NAME##Op : public XlaBinaryOp { \
public: \
explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \
- xla::ComputationDataHandle Computation( \
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, \
- const gtl::ArraySlice& lhs_shape, \
- const xla::ComputationDataHandle& rhs, \
+ xla::XlaOp Computation( \
+ XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \
+ const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, \
const gtl::ArraySlice& rhs_shape, \
const BCast& broadcast_helper, \
const std::vector& extend_dimensions) override { \
- xla::ComputationBuilder* b = ctx->builder(); \
+ xla::XlaBuilder* b = ctx->builder(); \
return HLO; \
} \
}; \
@@ -63,11 +62,8 @@ XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions));
// } else {
// return x / y;
// }
-static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b,
- DataType dtype,
- xla::ComputationDataHandle x,
- xla::ComputationDataHandle y,
- const BCast& broadcast_helper) {
+static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
@@ -87,11 +83,8 @@ XLA_MAKE_BINARY(FloorDiv,
// Implementation of FloorMod. Pseudo-code:
// T trunc_mod = std::fmod(x, y);
// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
-static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b,
- DataType dtype,
- xla::ComputationDataHandle x,
- xla::ComputationDataHandle y,
- const BCast& broadcast_helper) {
+static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero));
@@ -127,8 +120,7 @@ XLA_MAKE_BINARY(SqrtGrad,
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
lhs, extend_dimensions));
-static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& x) {
+static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) {
return builder->Mul(x, x);
}
@@ -175,11 +167,11 @@ class ApproximateEqualOp : public XlaOpKernel {
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1)));
auto abs_shape = b->GetShape(abs);
OP_REQUIRES_OK(ctx, abs_shape.status());
- auto abs_type = abs_shape.ValueOrDie()->element_type();
+ auto abs_type = abs_shape.ValueOrDie().element_type();
auto result = b->Lt(
abs, b->ConvertElementType(b->ConstantR0(tolerance_), abs_type));
ctx->SetOutput(0, result);
diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
index c52b2dcb7e9ef81fd52565dfbda05e33a52ed43a..e9d98c768572c52825fa5192ecec834889f040fe 100644
--- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
@@ -33,9 +33,9 @@ class CastOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
- xla::ComputationDataHandle output;
+ xla::XlaBuilder* builder = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
+ xla::XlaOp output;
if (src_dtype_ == dst_dtype_) {
output = input;
@@ -72,9 +72,9 @@ class BitcastOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
- xla::ComputationDataHandle output;
+ xla::XlaBuilder* builder = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
+ xla::XlaOp output;
if (src_dtype_ == dst_dtype_) {
output = input;
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index 545aa364f937b2dc972dbe7b8c18b5897aa8e5c3..835a7f568945f0bee86fe2b39491c3326726e1aa 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -34,7 +34,7 @@ class CategoricalOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
// Get the logits
- const xla::ComputationDataHandle& logits = ctx->Input(0);
+ const xla::XlaOp& logits = ctx->Input(0);
TensorShape logits_shape = ctx->InputShape(0);
int64 num_samples;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_samples));
@@ -56,7 +56,7 @@ class CategoricalOp : public XlaOpKernel {
const int64 batch_size = logits_shape.dim_size(0);
const int64 num_classes = logits_shape.dim_size(1);
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
std::array uniform_shape_array = {
{batch_size, num_samples, num_classes}};
@@ -78,7 +78,7 @@ class CategoricalOp : public XlaOpKernel {
/*broadcast_dimensions=*/{0, 2});
TensorShape softmax_shape(uniform_shape_array);
- xla::ComputationDataHandle argmax;
+ xla::XlaOp argmax;
OP_REQUIRES_OK(
ctx,
XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape,
diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
index fdf75be7b1156540d762e3bc04a51f2478f00f46..a00bc912f9f40052565446c6bf9390629af9a4cd 100644
--- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
@@ -29,7 +29,7 @@ class ClipByValueOp : public XlaOpKernel {
const TensorShape min_shape = ctx->InputShape(1);
const TensorShape max_shape = ctx->InputShape(2);
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
auto input = ctx->Input(0);
auto min = ctx->Input(1);
auto max = ctx->Input(2);
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index 1a246e8df9b2cd83147b50d960744332f8582a51..78285affa1c399ae107a9172fb85cf257457c368 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -54,7 +54,7 @@ class ConcatBaseOp : public XlaOpKernel {
// TODO(annarev): add a helper to support int64 input.
const int32 concat_dim = literal.Get({});
- std::vector values;
+ std::vector values;
std::vector shapes;
OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes));
const int N = values.size();
@@ -70,13 +70,13 @@ class ConcatBaseOp : public XlaOpKernel {
"[",
-input_dims, ", ", input_dims, "), but got ", concat_dim));
- // Make a vector holding the ComputationDataHandles for each of
- // the inputs that has non-zero elements.
- std::vector input_data;
+ // Make a vector holding the XlaOp for each of the inputs that has non-zero
+ // elements.
+ std::vector input_data;
int output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
- xla::ComputationDataHandle handle = values[i];
+ xla::XlaOp handle = values[i];
const TensorShape& in_shape = shapes[i];
const bool in_is_scalar = IsLegacyScalar(in_shape);
OP_REQUIRES(
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index 8f78b4c8f90cf00d5fa9ba71a78bb1c0fe280dc6..59d06c654de18c9003fe0bdc706d0c2443de6d7b 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -45,7 +45,7 @@ class ConstOp : public XlaOpKernel {
ctx->SetInvalidOutput(0);
return;
}
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
// To avoid blowups for large constants filled with the same value,
// recognize that case and emit a scalar broadcast instead.
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index c0ee0c9c2ea849a692bee70bba36d32335eed9b5..627bad12f33c82e91bc3c6f3323f562bc8174056 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -47,9 +47,8 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution(
}
// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution.
-xla::ComputationDataHandle CreateExpandedZero(
- const TensorShape& filter_shape, DataType dtype,
- xla::ComputationBuilder* builder) {
+xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype,
+ xla::XlaBuilder* builder) {
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
return builder->Broadcast(XlaHelpers::Zero(builder, dtype),
@@ -87,8 +86,8 @@ xla::ComputationDataHandle CreateExpandedZero(
//
// Finally compare A and broadcasted B in dimension 2 amd return the result at
// the beginning of the comment.
-xla::ComputationDataHandle CreateExpandedFilterMask(
- const TensorShape& filter_shape, xla::ComputationBuilder* builder) {
+xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
+ xla::XlaBuilder* builder) {
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
@@ -96,11 +95,11 @@ xla::ComputationDataHandle CreateExpandedFilterMask(
// Create a M sized linspace and an M*N sized linspace that will be
// broadcasted into perpendicular dimensions and compared.
- xla::ComputationDataHandle input_feature_iota;
+ xla::XlaOp input_feature_iota;
// DT_INT32 Iota will always return status::OK().
TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature,
&input_feature_iota));
- xla::ComputationDataHandle expanded_feature_iota;
+ xla::XlaOp expanded_feature_iota;
TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
input_feature * depthwise_multiplier,
&expanded_feature_iota));
@@ -126,10 +125,10 @@ xla::ComputationDataHandle CreateExpandedFilterMask(
// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding
// zeros for the cross-depth filters. Used to build a depthwise convolution.
-xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution(
- const TensorShape& filter_shape, DataType dtype,
- const xla::ComputationDataHandle& filter,
- xla::ComputationBuilder* builder) {
+xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape,
+ DataType dtype,
+ const xla::XlaOp& filter,
+ xla::XlaBuilder* builder) {
int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
TensorShape expanded_filter_shape =
@@ -156,10 +155,11 @@ xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution(
}
// Inverse of ExpandFilterForDepthwiseConvolution.
-xla::ComputationDataHandle ContractFilterForDepthwiseBackprop(
- XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype,
- const xla::ComputationDataHandle& filter_backprop,
- xla::ComputationBuilder* builder) {
+xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
+ const TensorShape& filter_shape,
+ DataType dtype,
+ const xla::XlaOp& filter_backprop,
+ xla::XlaBuilder* builder) {
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
auto masked_expanded_filter = builder->Select(
@@ -248,9 +248,9 @@ class ConvOp : public XlaOpKernel {
"input and filter must have the same depth: ", in_depth,
" vs ", input_shape.dim_size(feature_dim)));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
- xla::ComputationDataHandle filter = ctx->Input(1);
+ xla::XlaOp filter = ctx->Input(1);
TensorShape expanded_filter_shape = filter_shape;
if (depthwise_) {
filter = ExpandFilterForDepthwiseConvolution(
@@ -288,7 +288,7 @@ class ConvOp : public XlaOpKernel {
&unused_output_size, &padding[i].first, &padding[i].second));
}
- xla::ComputationDataHandle conv =
+ xla::XlaOp conv =
b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
lhs_dilation, rhs_dilation, dims);
ctx->SetOutput(0, conv);
@@ -391,7 +391,7 @@ class ConvBackpropInputOp : public XlaOpKernel {
expanded_filter_shape, out_backprop_shape, dilations_,
strides_, padding_, data_format_, &dims));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
auto filter = ctx->Input(1);
auto out_backprop = ctx->Input(2);
@@ -435,12 +435,11 @@ class ConvBackpropInputOp : public XlaOpKernel {
}
// Mirror the filter in the spatial dimensions.
- xla::ComputationDataHandle mirrored_weights =
- b->Rev(filter, kernel_spatial_dims);
+ xla::XlaOp mirrored_weights = b->Rev(filter, kernel_spatial_dims);
// activation gradients
// = gradients (with padding and dilation) mirrored_weights
- xla::ComputationDataHandle in_backprop = b->ConvGeneralDilated(
+ xla::XlaOp in_backprop = b->ConvGeneralDilated(
out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
lhs_dilation, rhs_dilation, dnums);
@@ -546,9 +545,9 @@ class ConvBackpropFilterOp : public XlaOpKernel {
expanded_filter_shape, out_backprop_shape, dilations_,
strides_, padding_, data_format_, &dims));
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle activations = ctx->Input(0);
- xla::ComputationDataHandle gradients = ctx->Input(2);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp activations = ctx->Input(0);
+ xla::XlaOp gradients = ctx->Input(2);
// The filter gradients are computed by a convolution of the input
// activations and the output gradients, with some appropriate padding.
diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
index 3df8c00f1b83556d7d954aedc8eeac0728251c3e..7fcd4170fb79a574663c1abffe873d4b53f471d3 100644
--- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
@@ -53,7 +53,7 @@ class CrossOp : public XlaOpKernel {
}
std::vector strides(in0_shape.dims(), 1);
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
auto in0 = ctx->Input(0);
auto in1 = ctx->Input(1);
starts.back() = 0;
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index 0cf03ceb948a5165a71e902eef5264eaddbd71e9..01aa1a83e7967921f1583b3ef18ec57e452dcfea 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/bcast.h"
@@ -75,7 +75,7 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
}
// Call virtual method to emit the computation.
- xla::ComputationDataHandle output =
+ xla::XlaOp output =
Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle,
rhs_shape.dim_sizes(), bcast, extend_dimension);
@@ -85,11 +85,9 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
ctx->SetOutput(0, output);
}
-/* static */ std::pair
-XlaBinaryOp::Broadcast(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& lhs,
- const xla::ComputationDataHandle& rhs,
- const BCast& broadcast_helper) {
+/* static */ std::pair XlaBinaryOp::Broadcast(
+ xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
+ const BCast& broadcast_helper) {
// Manually construct the broadcasting since MapN does not do
// automatic broadcasting. The bcast helper ensures that
// lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
index 5bc1d5fb1f08fb576df654e1f4068b6be9114096..4f92dbc8740b697322424058530b8477c35d809a 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/bcast.h"
@@ -30,7 +30,7 @@ namespace tensorflow {
// inputs that can be broadcast to the same shape. The base class
// contains pure virtual methods to override: description is a textual
// description of the operation; and Computation adds the
-// implementation of the operation to a xla::ComputationBuilder. For most
+// implementation of the operation to a xla::XlaBuilder. For most
// arithmetic Ops XLA handles the broadcasting automatically given the input
// tensors.
class XlaBinaryOp : public XlaOpKernel {
@@ -55,10 +55,9 @@ class XlaBinaryOp : public XlaOpKernel {
// higher-rank input should be matched when broadcasting the
// lower-rank input. See comment below and the documentation on broadcasting
// in the XLA documentation.
- virtual xla::ComputationDataHandle Computation(
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs,
- const gtl::ArraySlice& lhs_shape,
- const xla::ComputationDataHandle& rhs,
+ virtual xla::XlaOp Computation(
+ XlaOpKernelContext* ctx, const xla::XlaOp& lhs,
+ const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs,
const gtl::ArraySlice& rhs_shape, const BCast& broadcast_helper,
const std::vector& extend_dimensions) = 0;
@@ -67,11 +66,9 @@ class XlaBinaryOp : public XlaOpKernel {
// Helper function that performs the broadcasting described by
// 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same
// shape.
- static std::pair
- Broadcast(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& lhs,
- const xla::ComputationDataHandle& rhs,
- const BCast& broadcast_helper);
+ static std::pair Broadcast(
+ xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
+ const BCast& broadcast_helper);
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
index 96d7809f7995634b6bc31ab801b93526d9da7e6f..23243f62462c6315e359d9621823b19fc98c6218 100644
--- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
@@ -50,8 +50,8 @@ class DepthToSpaceOp : public XlaOpKernel {
const gtl::InlinedVector input_shape =
input_tensor_shape.dim_sizes();
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_);
@@ -130,7 +130,7 @@ class DepthToSpaceOp : public XlaOpKernel {
") is not divisible by square of the block size (",
block_size_, ")"));
- xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce
// `permuted_reshaped` of shape:
@@ -141,8 +141,7 @@ class DepthToSpaceOp : public XlaOpKernel {
// input_shape[2],
// block_size_,
// depth / (block_size_ * block_size_)]
- xla::ComputationDataHandle permuted_reshaped =
- b->Transpose(reshaped, transpose_order);
+ xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order);
// 3. Reshape `permuted_reshaped` to flatten `block_shape` into the
// batch dimension, producing an output tensor of shape:
@@ -152,8 +151,7 @@ class DepthToSpaceOp : public XlaOpKernel {
// input_shape[2] * block_size_,
// depth / (block_size_ * block_size_)]
//
- xla::ComputationDataHandle output =
- b->Reshape(permuted_reshaped, output_shape);
+ xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 765ea922a532a085a552192348ab360c4c30ff0a..931705ba837153e1175cd9a209876ef5ec93f0fc 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -25,10 +25,10 @@ namespace tensorflow {
namespace {
// Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
-xla::StatusOr CreateDiagonal(
- const xla::ComputationDataHandle& input, int64 last_dim_size,
+xla::StatusOr CreateDiagonal(
+ const xla::XlaOp& input, int64 last_dim_size,
tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx,
- xla::ComputationBuilder* builder) {
+ xla::XlaBuilder* builder) {
// Create two matrices that have the following forms, and compare them:
//
// [[0, 0, 0, 0] [[0, 1, 2, 3]
@@ -38,12 +38,11 @@ xla::StatusOr CreateDiagonal(
//
// This produces a predicate matrix of the right size, with "true" on the
// diagonal.
- xla::ComputationDataHandle iota;
+ xla::XlaOp iota;
TF_RETURN_IF_ERROR(
XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota));
- xla::ComputationDataHandle iota_broadcast =
- builder->Broadcast(iota, {last_dim_size});
- xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0});
+ xla::XlaOp iota_broadcast = builder->Broadcast(iota, {last_dim_size});
+ xla::XlaOp mask = builder->Eq(iota_broadcast, iota, {0});
// If this is a batched diagonal, broadcast the mask across the other
// dimensions.
@@ -65,8 +64,7 @@ xla::StatusOr CreateDiagonal(
std::vector broadcast_dims(other_dims.begin(), other_dims.end());
broadcast_dims.push_back(1LL);
broadcast_dims.push_back(last_dim_size);
- xla::ComputationDataHandle input_broadcast =
- builder->Reshape(input, broadcast_dims);
+ xla::XlaOp input_broadcast = builder->Reshape(input, broadcast_dims);
broadcast_dims[broadcast_dims.size() - 2] = last_dim_size;
xla::PrimitiveType element_type;
@@ -74,7 +72,7 @@ xla::StatusOr CreateDiagonal(
DataTypeToPrimitiveType(ctx->input_type(0), &element_type));
auto broadcast_shape =
xla::ShapeUtil::MakeShape(element_type, broadcast_dims);
- xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape);
+ xla::XlaOp zeros = Zeros(builder, broadcast_shape);
input_broadcast = builder->Add(input_broadcast, zeros);
return builder->Select(mask, input_broadcast, zeros);
@@ -85,7 +83,7 @@ class DiagOp : public XlaOpKernel {
explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("Diag op must have at an input"));
@@ -96,7 +94,7 @@ class DiagOp : public XlaOpKernel {
errors::InvalidArgument("Expected 1 <= dims, got shape ",
input_shape.DebugString()));
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
// Picture:
// tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0]
@@ -112,7 +110,7 @@ class DiagOp : public XlaOpKernel {
auto diag_or_status =
CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder);
OP_REQUIRES_OK(ctx, diag_or_status.status());
- xla::ComputationDataHandle diag = diag_or_status.ValueOrDie();
+ xla::XlaOp diag = diag_or_status.ValueOrDie();
// Reshapes to the final shape.
std::vector new_dims(dims.size() * 2);
@@ -131,7 +129,7 @@ class DiagPartOp : public XlaOpKernel {
explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
const TensorShape input_shape = ctx->InputShape(0);
auto dims = input_shape.dim_sizes();
@@ -158,7 +156,7 @@ class DiagPartOp : public XlaOpKernel {
new_dims.push_back(dims[i]);
}
- xla::ComputationDataHandle diag = ctx->Input(0);
+ xla::XlaOp diag = ctx->Input(0);
// TODO(b/30878775): use Slice with strides when supported, in place of
// the Pad -> Reshape -> Slice.
@@ -199,7 +197,7 @@ class MatrixDiagOp : public XlaOpKernel {
explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("MatrixDiag op must have at an input"));
@@ -210,7 +208,7 @@ class MatrixDiagOp : public XlaOpKernel {
errors::InvalidArgument("Expected 1 <= dims, got shape ",
input_shape.DebugString()));
- xla::ComputationDataHandle diag = ctx->Input(0);
+ xla::XlaOp diag = ctx->Input(0);
int last_dim = dims.size() - 1;
int64 last_dim_size = input_shape.dim_size(last_dim);
@@ -232,7 +230,7 @@ class MatrixDiagPartOp : public XlaOpKernel {
explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
const TensorShape input_shape = ctx->InputShape(0);
auto dims = input_shape.dim_sizes();
@@ -241,7 +239,7 @@ class MatrixDiagPartOp : public XlaOpKernel {
errors::InvalidArgument("Expected 2 <= dims, got shape ",
input_shape.DebugString()));
- xla::ComputationDataHandle diag = ctx->Input(0);
+ xla::XlaOp diag = ctx->Input(0);
int last_dim = dims.size() - 1;
int64 last_dim_size = dims[last_dim];
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
index 800ef5ab98d70ad822c6efffb33db28b46ae50fe..0419de78b2ee83fd395e8bf23444fde84f30bba2 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@@ -57,7 +57,7 @@ class DynamicUpdateSliceOp : public XlaOpKernel {
input_shape.DebugString(), "; update shape is ",
update_shape.DebugString()));
- xla::ComputationDataHandle result = ctx->builder()->DynamicUpdateSlice(
+ xla::XlaOp result = ctx->builder()->DynamicUpdateSlice(
ctx->Input(0), ctx->Input(1), ctx->Input(2));
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index f2cd21ffb9ce88747c04f3c71e66dadeb1faf0f9..dd4a16908779508380b36f43ce2306ff2f5fb8c4 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -56,7 +56,7 @@ class DynamicStitchOp : public XlaOpKernel {
std::vector indices_input;
OP_REQUIRES_OK(ctx, ctx->ConstantInputList("indices", &indices_input));
- std::vector data;
+ std::vector data;
std::vector data_shapes;
OP_REQUIRES_OK(ctx, ctx->InputList("data", &data, &data_shapes));
@@ -136,7 +136,7 @@ class DynamicStitchOp : public XlaOpKernel {
// Look up all the children expressions that represent the data
// inputs.
- std::vector input(indices.size());
+ std::vector input(indices.size());
for (int input_num = 0; input_num < indices.size(); input_num++) {
TensorShape new_shape;
// first reshaped dimension is the number of indices for this input.
@@ -166,7 +166,7 @@ class DynamicStitchOp : public XlaOpKernel {
for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d);
}
- std::vector to_concat(number_of_indices);
+ std::vector to_concat(number_of_indices);
for (int index_num = 0; index_num < number_of_indices; index_num++) {
const auto& expression = input[src_input_vector[index_num]];
// Take the appropriate slice of data.
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
index 2fd27c5ca7e87c8b387d9d0854b787d30e7f7b6f..ed7462c16615f7f63a174e29843c2a1675c17058 100644
--- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
@@ -32,7 +32,7 @@ class EluOp : public XlaOpKernel {
explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
const auto one = XlaHelpers::One(b, input_type(0));
const auto pred = b->Gt(ctx->Input(0), zero);
@@ -47,7 +47,7 @@ class EluGradOp : public XlaOpKernel {
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
// otherwise return lhs * (1 + rhs).
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
const auto one = XlaHelpers::One(b, input_type(0));
const auto grad = ctx->Input(0);
@@ -66,7 +66,7 @@ class SeluOp : public XlaOpKernel {
explicit SeluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
const auto one = XlaHelpers::One(b, input_type(0));
const auto scale = XlaHelpers::FloatLiteral(b, input_type(0),
@@ -86,9 +86,8 @@ class SeluGradOp : public XlaOpKernel {
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
// otherwise return lhs * (1 + rhs).
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
- const auto one = XlaHelpers::One(b, input_type(0));
const auto scale = XlaHelpers::FloatLiteral(b, input_type(0),
1.0507009873554804934193349852946);
const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0),
diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
index b2970eae20a3fb71f06619f476a49d41b22bca56..6df01cabbf1d98c0299bfd808bcc6db6223c4777 100644
--- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
@@ -93,7 +93,7 @@ class ExtractImagePatchesOp : public XlaOpKernel {
input_shape.DebugString()));
const int64 depth = input_shape.dim_size(feature_dim);
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
// The following code is equivalent to:
// eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD])
@@ -110,7 +110,7 @@ class ExtractImagePatchesOp : public XlaOpKernel {
// Builds an identity matrix as a broadcast equality of iotas.
// iota = np.arange(np.prod(ksize), depth)
// filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
- xla::ComputationDataHandle iota;
+ xla::XlaOp iota;
TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
kernel_size * depth, &iota));
@@ -147,7 +147,7 @@ class ExtractImagePatchesOp : public XlaOpKernel {
&padding[i].first, &padding[i].second));
}
- xla::ComputationDataHandle conv =
+ xla::XlaOp conv =
builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
padding, lhs_dilation, rhs_dilation, dims);
ctx->SetOutput(0, conv);
diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
index 99470d70e709ddb5593c5eaae061bb897befc168..8f0de0a524c908b598c1a2165a462275346ad137 100644
--- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
@@ -44,23 +44,20 @@ void CpuNudge(const float min, const float max, const float quant_min,
}
// An XLA version of CpuNudge().
-void XlaNudge(xla::ComputationBuilder* b, const DataType data_type,
- const xla::ComputationDataHandle& min,
- const xla::ComputationDataHandle& max,
+void XlaNudge(xla::XlaBuilder* b, const DataType data_type,
+ const xla::XlaOp& min, const xla::XlaOp& max,
const float quant_min_value, const float quant_max_value,
- xla::ComputationDataHandle* nudged_min,
- xla::ComputationDataHandle* nudged_max,
- xla::ComputationDataHandle* scale) {
+ xla::XlaOp* nudged_min, xla::XlaOp* nudged_max,
+ xla::XlaOp* scale) {
*scale = b->Div(b->Sub(max, min),
XlaHelpers::FloatLiteral(b, data_type,
quant_max_value - quant_min_value));
- xla::ComputationDataHandle quant_min =
+ xla::XlaOp quant_min =
XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
- xla::ComputationDataHandle zero_point_from_min =
- b->Sub(quant_min, b->Div(min, *scale));
- xla::ComputationDataHandle quant_max =
+ xla::XlaOp zero_point_from_min = b->Sub(quant_min, b->Div(min, *scale));
+ xla::XlaOp quant_max =
XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
- xla::ComputationDataHandle nudged_zero_point =
+ xla::XlaOp nudged_zero_point =
b->Select(b->Le(zero_point_from_min, quant_min), quant_min,
b->Select(b->Ge(zero_point_from_min, quant_max), quant_max,
b->Round(zero_point_from_min)));
@@ -68,22 +65,18 @@ void XlaNudge(xla::ComputationBuilder* b, const DataType data_type,
*nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale);
}
-xla::ComputationDataHandle Quantize(
- xla::ComputationBuilder* b, const xla::ComputationDataHandle& input,
- const DataType data_type,
- const xla::ComputationDataHandle& nudged_input_min,
- const xla::ComputationDataHandle& nudged_input_max,
- const xla::ComputationDataHandle& input_scale) {
- xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
- xla::ComputationDataHandle inv_scale = b->Div(one, input_scale);
- xla::ComputationDataHandle half =
- XlaHelpers::FloatLiteral(b, data_type, 0.5f);
-
- xla::ComputationDataHandle clamped =
- b->Clamp(nudged_input_min, input, nudged_input_max);
- xla::ComputationDataHandle clamped_shifted =
- b->Sub(clamped, nudged_input_min);
- xla::ComputationDataHandle rounded =
+xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
+ const DataType data_type,
+ const xla::XlaOp& nudged_input_min,
+ const xla::XlaOp& nudged_input_max,
+ const xla::XlaOp& input_scale) {
+ xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
+ xla::XlaOp inv_scale = b->Div(one, input_scale);
+ xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f);
+
+ xla::XlaOp clamped = b->Clamp(nudged_input_min, input, nudged_input_max);
+ xla::XlaOp clamped_shifted = b->Sub(clamped, nudged_input_min);
+ xla::XlaOp rounded =
b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half));
return b->Add(b->Mul(rounded, input_scale), nudged_input_min);
}
@@ -111,18 +104,18 @@ class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle nudged_input_min =
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp nudged_input_min =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
- xla::ComputationDataHandle nudged_input_max =
+ xla::XlaOp nudged_input_max =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
- xla::ComputationDataHandle input_scale =
+ xla::XlaOp input_scale =
XlaHelpers::FloatLiteral(b, data_type, input_scale_);
- xla::ComputationDataHandle output = Quantize(
- b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
+ xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
+ nudged_input_max, input_scale);
ctx->SetOutput(0, output);
}
@@ -159,23 +152,22 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle gradient = ctx->Input(0);
+ xla::XlaOp gradient = ctx->Input(0);
const TensorShape gradient_shape = ctx->InputShape(0);
- xla::ComputationDataHandle input = ctx->Input(1);
+ xla::XlaOp input = ctx->Input(1);
const DataType data_type = ctx->input_type(1);
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle nudged_input_min =
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp nudged_input_min =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
- xla::ComputationDataHandle nudged_input_max =
+ xla::XlaOp nudged_input_max =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
- xla::ComputationDataHandle between_nudged_min_max =
+ xla::XlaOp between_nudged_min_max =
b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
- xla::ComputationDataHandle zeroes = b->Broadcast(
- XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes());
- xla::ComputationDataHandle output =
- b->Select(between_nudged_min_max, gradient, zeroes);
+ xla::XlaOp zeroes = b->Broadcast(XlaHelpers::Zero(b, data_type),
+ gradient_shape.dim_sizes());
+ xla::XlaOp output = b->Select(between_nudged_min_max, gradient, zeroes);
ctx->SetOutput(0, output);
}
@@ -204,18 +196,18 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
- xla::ComputationDataHandle input_min = ctx->Input(1);
- xla::ComputationDataHandle input_max = ctx->Input(2);
+ xla::XlaOp input_min = ctx->Input(1);
+ xla::XlaOp input_max = ctx->Input(2);
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
&nudged_input_min, &nudged_input_max, &input_scale);
- xla::ComputationDataHandle output = Quantize(
- b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
+ xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
+ nudged_input_max, input_scale);
ctx->SetOutput(0, output);
}
@@ -243,47 +235,43 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle gradient = ctx->Input(0);
+ xla::XlaOp gradient = ctx->Input(0);
const TensorShape gradient_shape = ctx->InputShape(0);
- xla::ComputationDataHandle input = ctx->Input(1);
+ xla::XlaOp input = ctx->Input(1);
const DataType data_type = ctx->input_type(1);
const DataType accumulation_type =
XlaHelpers::SumAccumulationType(data_type);
- xla::ComputationDataHandle input_min = ctx->Input(2);
- xla::ComputationDataHandle input_max = ctx->Input(3);
+ xla::XlaOp input_min = ctx->Input(2);
+ xla::XlaOp input_max = ctx->Input(3);
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
&nudged_input_min, &nudged_input_max, &input_scale);
- xla::ComputationDataHandle between_nudged_min_max =
+ xla::XlaOp between_nudged_min_max =
b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
- xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type);
- xla::ComputationDataHandle zeroes =
- b->Broadcast(zero, gradient_shape.dim_sizes());
- xla::ComputationDataHandle output0 =
- b->Select(between_nudged_min_max, gradient, zeroes);
+ xla::XlaOp zero = XlaHelpers::Zero(b, data_type);
+ xla::XlaOp zeroes = b->Broadcast(zero, gradient_shape.dim_sizes());
+ xla::XlaOp output0 = b->Select(between_nudged_min_max, gradient, zeroes);
ctx->SetOutput(0, output0);
- xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min);
- xla::ComputationDataHandle select1 = b->Select(below_min, gradient, zeroes);
- xla::ComputationDataHandle reduce1 = b->ReduceAll(
+ xla::XlaOp below_min = b->Lt(input, nudged_input_min);
+ xla::XlaOp select1 = b->Select(below_min, gradient, zeroes);
+ xla::XlaOp reduce1 = b->ReduceAll(
XlaHelpers::ConvertElementType(b, select1, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type));
- xla::ComputationDataHandle output1 =
- XlaHelpers::ConvertElementType(b, reduce1, data_type);
+ xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type);
ctx->SetOutput(1, output1);
- xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max);
- xla::ComputationDataHandle select2 = b->Select(above_max, gradient, zeroes);
- xla::ComputationDataHandle reduce2 = b->ReduceAll(
+ xla::XlaOp above_max = b->Gt(input, nudged_input_max);
+ xla::XlaOp select2 = b->Select(above_max, gradient, zeroes);
+ xla::XlaOp reduce2 = b->ReduceAll(
XlaHelpers::ConvertElementType(b, select2, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type));
- xla::ComputationDataHandle output2 =
- XlaHelpers::ConvertElementType(b, reduce2, data_type);
+ xla::XlaOp output2 = XlaHelpers::ConvertElementType(b, reduce2, data_type);
ctx->SetOutput(2, output2);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
index a4f3c1c3ad9a928e0552c388a25ed9fcb08edabb..933924cad1c7cac2879bd4720cb21ffc33c23f50 100644
--- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -62,9 +62,8 @@ class GenericFftOp : public XlaOpKernel {
}
}
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle fft =
- b->Fft(ctx->Input(0), fft_type_, fft_length);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp fft = b->Fft(ctx->Input(0), fft_type_, fft_length);
ctx->SetOutput(0, fft);
}
@@ -82,9 +81,11 @@ class FFTOp : public GenericFftOp {
explicit FFTOp(OpKernelConstruction* ctx)
: GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {}
};
-REGISTER_XLA_OP(Name("FFT"), FFTOp<1>);
-REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>);
-REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>);
+REGISTER_XLA_OP(Name("FFT").TypeConstraint("Tcomplex", DT_COMPLEX64), FFTOp<1>);
+REGISTER_XLA_OP(Name("FFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+ FFTOp<2>);
+REGISTER_XLA_OP(Name("FFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+ FFTOp<3>);
template
class IFFTOp : public GenericFftOp {
@@ -92,9 +93,12 @@ class IFFTOp : public GenericFftOp {
explicit IFFTOp(OpKernelConstruction* ctx)
: GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {}
};
-REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>);
-REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>);
-REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>);
+REGISTER_XLA_OP(Name("IFFT").TypeConstraint("Tcomplex", DT_COMPLEX64),
+ IFFTOp<1>);
+REGISTER_XLA_OP(Name("IFFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+ IFFTOp<2>);
+REGISTER_XLA_OP(Name("IFFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+ IFFTOp<3>);
template
class RFFTOp : public GenericFftOp {
diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
index eaa13b8dfacce9aaca42ce5fcdfa467ce7fa7b7f..e4467a0fb138ed7919af62ed032c0f5abee3e4f6 100644
--- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
@@ -48,7 +48,7 @@ class FillOp : public XlaOpKernel {
0, {dims_shape.num_elements()}, &dims_literal));
// Convert the dims literal into a vector that we can pass to
- // ComputationBuilder.
+ // XlaBuilder.
std::vector broadcast;
broadcast.reserve(dims_literal.shape().dimensions(0));
for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
@@ -56,7 +56,7 @@ class FillOp : public XlaOpKernel {
}
// Look up the value input, reshaping to a scalar if it was a
// 'legacy' scalar (secretly a vector).
- xla::ComputationDataHandle data = ctx->Input(1);
+ xla::XlaOp data = ctx->Input(1);
if (value_shape.dims() > 0) {
CHECK_EQ(value_shape.dims(), 1);
data = ctx->builder()->Reshape(data, {});
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 0b79cb0916ee8a7d0e26c5dc12557639336f8ab1..d13e25bcddae16d0cd630403219657121b80868d 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -26,13 +26,11 @@ limitations under the License.
namespace tensorflow {
-Status XlaGather(const xla::ComputationDataHandle& input,
- const TensorShape& input_shape,
- const xla::ComputationDataHandle& indices,
- const TensorShape& indices_shape, int64 axis,
- bool indices_are_nd, DataType dtype, DataType index_type,
- xla::ComputationBuilder* builder,
- xla::ComputationDataHandle* gather_output) {
+Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
+ const xla::XlaOp& indices, const TensorShape& indices_shape,
+ int64 axis, bool indices_are_nd, DataType dtype,
+ DataType index_type, xla::XlaBuilder* builder,
+ xla::XlaOp* gather_output) {
// There is no deep reason why we need this precondition, but this is the only
// combination that is used and tested today.
CHECK(!indices_are_nd || axis == 0);
@@ -153,7 +151,7 @@ class GatherOp : public XlaOpKernel {
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
- xla::ComputationBuilder* builder = context->builder();
+ xla::XlaBuilder* builder = context->builder();
auto input = context->Input(0);
auto input_shape = context->InputShape(0);
auto indices = context->Input(1);
@@ -182,7 +180,7 @@ class GatherOp : public XlaOpKernel {
OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
errors::InvalidArgument("indices must be int32 or int64"));
- xla::ComputationDataHandle gather;
+ xla::XlaOp gather;
OP_REQUIRES_OK(
context, XlaGather(input, input_shape, indices, indices_shape, axis,
/*indices_are_nd=*/false, input_type(0), index_type,
@@ -220,10 +218,10 @@ class GatherNdOp : public XlaOpKernel {
indices_shape.dim_size(indices_shape.dims() - 1), " vs. ",
params_shape.dims()));
- xla::ComputationBuilder* builder = context->builder();
+ xla::XlaBuilder* builder = context->builder();
auto params = context->Input(0);
auto indices = context->Input(1);
- xla::ComputationDataHandle gather;
+ xla::XlaOp gather;
OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices,
indices_shape, /*axis=*/0,
/*indices_are_nd=*/true, params_type,
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
index f9376f0eabdc0f0c565eb4b9f86425de96b5aa22..d898e43b858bac706d524c7c271f48b1b5fa258f 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/bcast.h"
@@ -33,13 +33,11 @@ namespace tensorflow {
// If `indices_are_nd` is true, the last dimension of `indices` are treated as
// a multidimensional index values. Otherwise, `indices` is treated as a tensor
// of scalar indices.
-Status XlaGather(const xla::ComputationDataHandle& input,
- const TensorShape& input_shape,
- const xla::ComputationDataHandle& indices,
- const TensorShape& indices_shape, int64 axis,
- bool indices_are_nd, DataType dtype, DataType index_type,
- xla::ComputationBuilder* builder,
- xla::ComputationDataHandle* gather_output);
+Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
+ const xla::XlaOp& indices, const TensorShape& indices_shape,
+ int64 axis, bool indices_are_nd, DataType dtype,
+ DataType index_type, xla::XlaBuilder* builder,
+ xla::XlaOp* gather_output);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index eefbe55c815d80a608bdf62d454a69d722adb158..8b9b026643cf35216a2082dfcce9270c017bd14f 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -37,7 +37,7 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
// TODO(b/35949885): There is duplication here with the handling of the
// while_op. Refactor the common code out/rework.
void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
OP_REQUIRES(ctx, cond_type_ == DT_BOOL,
errors::InvalidArgument(
@@ -48,7 +48,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "Building If: " << input_types_.size() << " inputs";
- std::vector inputs(input_types_.size());
+ std::vector inputs(input_types_.size());
std::vector arguments(input_types_.size());
for (int i = 0; i < input_types_.size(); ++i) {
XlaCompiler::Argument& arg = arguments[i];
@@ -175,19 +175,19 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
"Mismatch in resource of then and else branch for resource ", i));
}
- xla::ComputationDataHandle outputs =
+ xla::XlaOp outputs =
b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation,
b->Tuple(inputs), *else_result.computation);
// Sets non-variable outputs.
for (int i = 0; i < output_types_.size(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
- xla::ComputationDataHandle output_handle = b->GetTupleElement(outputs, i);
+ xla::XlaOp output_handle = b->GetTupleElement(outputs, i);
if (VLOG_IS_ON(2)) {
LOG(INFO) << "Setting output " << i;
auto shape_or = b->GetShape(output_handle);
if (shape_or.ok()) {
LOG(INFO) << "Shape for output " << i << ": "
- << xla::ShapeUtil::HumanString(*shape_or.ValueOrDie());
+ << xla::ShapeUtil::HumanString(shape_or.ValueOrDie());
} else {
LOG(INFO) << "Shape unknown for output " << i;
}
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index 5eeda79a935e8194a596d322b52add27846d378c..1568b33679963c1a6630525f60560180d40b8d53 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -23,10 +23,9 @@ namespace {
// Converts 'input' from RGB format to HSV format.
// 'shape' is the shape of the red/green/blue tensors.
-std::array RGBToHSV(
- XlaOpKernelContext* ctx, xla::ComputationBuilder* b,
- const std::array& rgb, DataType dtype,
- const TensorShape& shape) {
+std::array RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b,
+ const std::array& rgb,
+ DataType dtype, const TensorShape& shape) {
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
@@ -54,12 +53,12 @@ std::array RGBToHSV(
}
// Converts 'input' from HSV format to RGB format.
-std::array HSVToRGB(
- xla::ComputationBuilder* b,
- const std::array& hsv, DataType dtype) {
- xla::ComputationDataHandle hue = hsv[0];
- xla::ComputationDataHandle saturation = hsv[1];
- xla::ComputationDataHandle value = hsv[2];
+std::array HSVToRGB(xla::XlaBuilder* b,
+ const std::array& hsv,
+ DataType dtype) {
+ xla::XlaOp hue = hsv[0];
+ xla::XlaOp saturation = hsv[1];
+ xla::XlaOp value = hsv[2];
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
@@ -95,16 +94,16 @@ class RGBToHSVOp : public XlaOpKernel {
errors::FailedPrecondition("input must have 3 channels but input has ",
channels, " channels."));
- xla::ComputationBuilder* b = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
+ xla::XlaBuilder* b = context->builder();
+ xla::XlaOp input = context->Input(0);
- xla::ComputationDataHandle red =
+ xla::XlaOp red =
b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle green =
+ xla::XlaOp green =
b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle blue =
+ xla::XlaOp blue =
b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
@@ -133,15 +132,15 @@ class HSVToRGBOp : public XlaOpKernel {
errors::FailedPrecondition("input must have 3 channels but input has ",
channels, " channels."));
- xla::ComputationBuilder* b = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle hue =
+ xla::XlaBuilder* b = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp hue =
b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle saturation =
+ xla::XlaOp saturation =
b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle value =
+ xla::XlaOp value =
b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
@@ -174,9 +173,9 @@ class AdjustContrastOpV2 : public XlaOpKernel {
errors::InvalidArgument("contrast_factor must be scalar: ",
factor_shape.DebugString()));
- xla::ComputationBuilder* b = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle factor = context->Input(1);
+ xla::XlaBuilder* b = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp factor = context->Input(1);
DataType type = context->input_type(0);
@@ -221,19 +220,19 @@ class AdjustSaturationOp : public XlaOpKernel {
errors::InvalidArgument("input must have 3 channels but instead has ",
channels, " channels."));
- xla::ComputationBuilder* b = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle scale = context->Input(1);
+ xla::XlaBuilder* b = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp scale = context->Input(1);
DataType type = context->input_type(0);
- xla::ComputationDataHandle red =
+ xla::XlaOp red =
b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle green =
+ xla::XlaOp green =
b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle blue =
+ xla::XlaOp blue =
b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
@@ -271,19 +270,19 @@ class AdjustHueOp : public XlaOpKernel {
errors::InvalidArgument("input must have 3 channels but instead has ",
channels, " channels."));
- xla::ComputationBuilder* b = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle delta = context->Input(1);
+ xla::XlaBuilder* b = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp delta = context->Input(1);
DataType type = context->input_type(0);
- xla::ComputationDataHandle red =
+ xla::XlaOp red =
b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle green =
+ xla::XlaOp green =
b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle blue =
+ xla::XlaOp blue =
b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index f36b3f594826c27b7866d956c855aa3638db9cb4..9058cbc74762576c7e6f8ec1b2b0f6b247ac0502 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -99,9 +99,9 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters(
return dims;
}
-xla::ComputationDataHandle MakeBilinearResizeKernel(
- xla::ComputationBuilder* builder, gtl::ArraySlice kernel_size,
- int64 channels) {
+xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
+ gtl::ArraySlice kernel_size,
+ int64 channels) {
// Form a 2D convolution kernel like:
// 1 2 3 2 1
// 2 4 6 4 2
@@ -120,7 +120,7 @@ xla::ComputationDataHandle MakeBilinearResizeKernel(
return kernel;
};
- xla::ComputationDataHandle channels_iota;
+ xla::XlaOp channels_iota;
// DT_INT32 Iota will always return status::OK().
TF_CHECK_OK(
XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota));
@@ -139,10 +139,12 @@ xla::ComputationDataHandle MakeBilinearResizeKernel(
/*broadcast_dimensions=*/{0});
}
-xla::ComputationDataHandle ResizeUsingDilationAndConvolution(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& input,
- const int num_spatial_dims, std::vector in_size,
- std::vector out_size, const int64 channels) {
+xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
+ const xla::XlaOp& input,
+ const int num_spatial_dims,
+ std::vector in_size,
+ std::vector out_size,
+ const int64 channels) {
// Picture for a 1x3 to 1x4 resize:
// stride = 2, kernel size = 3
// Input:
@@ -168,9 +170,9 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolution(
ResizeConvolutionDims dims =
ComputeResizeConvolutionParameters(in_size, out_size);
- xla::ComputationDataHandle kernel =
+ xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
- xla::ComputationDataHandle output = builder->ConvGeneralDilated(
+ xla::XlaOp output = builder->ConvGeneralDilated(
input, kernel, dims.stride,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
@@ -189,10 +191,12 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolution(
return output;
}
-xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& grad,
- const int num_spatial_dims, std::vector in_size,
- std::vector grad_size, const int64 channels) {
+xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
+ const xla::XlaOp& grad,
+ const int num_spatial_dims,
+ std::vector in_size,
+ std::vector grad_size,
+ const int64 channels) {
ResizeConvolutionDims dims =
ComputeResizeConvolutionParameters(in_size, grad_size);
@@ -210,7 +214,7 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp(
}
dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
- xla::ComputationDataHandle kernel =
+ xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
// Broadcast the input kernel where the forward op expanded from a size == 1
@@ -223,7 +227,7 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp(
}
}
- xla::ComputationDataHandle output = builder->ConvGeneralDilated(
+ xla::XlaOp output = builder->ConvGeneralDilated(
grad, kernel, /*window_strides=*/dims.kernel_size,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
@@ -258,7 +262,7 @@ class ResizeBilinearOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
TensorShape input_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, input_shape.dims() == 4,
@@ -283,7 +287,7 @@ class ResizeBilinearOp : public XlaOpKernel {
const int num_spatial_dims = 2;
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
// If in_size[i] > 1 and out_size[i] == 1, slice out the first input in
// dimension i.
@@ -318,7 +322,7 @@ class ResizeBilinearOp : public XlaOpKernel {
// from image of size axb -> cxd is same as resizing axb -> exf -> cxd.
//
// This makes the convolutions kernels smaller and the operation faster.
- xla::ComputationDataHandle output = input;
+ xla::XlaOp output = input;
while (in_size != out_size) {
if (in_size[0] != 1 && in_size[1] != 1) {
std::vector k = {
@@ -369,7 +373,7 @@ class ResizeBilinearGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
TensorShape input_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, input_shape.dims() == 4,
@@ -406,9 +410,9 @@ class ResizeBilinearGradOp : public XlaOpKernel {
const int num_spatial_dims = 2;
- xla::ComputationDataHandle grad = ctx->Input(0);
+ xla::XlaOp grad = ctx->Input(0);
- xla::ComputationDataHandle output = grad;
+ xla::XlaOp output = grad;
while (in_size != grad_size) {
if (in_size[0] != 1 && in_size[1] != 1) {
std::vector k = {
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
index 7bf4b435f526afa93d8a218b191928acb932cd6b..36eb4c75454ed82804c40b82e5dbaec2eef0a719 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
@@ -61,10 +61,10 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
DataType index_type = output_type(0);
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
- xla::ComputationDataHandle output;
+ xla::XlaOp output;
if (is_min_) {
OP_REQUIRES_OK(ctx,
XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0),
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index b1f3c3c298ce0cadf38b9bda715761fe7e2896d7..2c2d88486fda99d2380382a3e2f633f5bdc7478c 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -71,10 +71,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(),
errors::InvalidArgument(
"ArgMax implementation requires a CustomCall on CPU"));
- xla::ComputationBuilder& b = *ctx->builder();
+ xla::XlaBuilder& b = *ctx->builder();
// XLA passes to the function, so it is not included here.
- std::vector args;
+ std::vector args;
args.push_back(ctx->Input(0));
args.push_back(b.ConstantLiteral(
*xla::Literal::CreateR1(input_shape.dim_sizes())));
@@ -91,7 +91,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
// Tell XLA to call the custom code, defined in
// index_ops_kernel_argmax_float_1d.cc.
- xla::ComputationDataHandle output;
+ xla::XlaOp output;
switch (input_shape.dims()) {
case 1:
output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape);
diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
index c177f08d9c4687bb13b98a4328bb3960519799c4..1decf7d72d72bb697477e7f841ced2a1a0d5fbe9 100644
--- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
@@ -33,7 +33,7 @@ class L2LossOp : public XlaOpKernel {
std::iota(dims.begin(), dims.end(), 0);
DataType dtype = ctx->input_type(0);
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
// output = sum(t ** 2) / 2
const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype);
diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
index 1cfee3070f384af0a7441a9c860c530dd1b42187..39fbf98a6274918840e9e351470f04c2d80c5d01 100644
--- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
@@ -38,8 +38,8 @@ class LRNOp : public XlaOpKernel {
OP_REQUIRES(ctx, in_shape.dims() == 4,
errors::InvalidArgument("in must be 4-dimensional"));
- xla::ComputationBuilder* builder = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaBuilder* builder = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
// sqr_sum[a, b, c, d] =
// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
@@ -111,10 +111,10 @@ class LRNGradOp : public XlaOpKernel {
"input_grads, input_image, and out_image should have the same "
"shape"));
- xla::ComputationBuilder* builder = ctx->builder();
- xla::ComputationDataHandle in_grads = ctx->Input(0);
- xla::ComputationDataHandle in_image = ctx->Input(1);
- xla::ComputationDataHandle out_image = ctx->Input(2);
+ xla::XlaBuilder* builder = ctx->builder();
+ xla::XlaOp in_grads = ctx->Input(0);
+ xla::XlaOp in_image = ctx->Input(1);
+ xla::XlaOp out_image = ctx->Input(2);
// This code is ported from tensorflow/core/kernels/lrn_op.cc. In Python
// pseudo-code, the Eigen code does this for each spatial position:
@@ -166,7 +166,7 @@ class LRNGradOp : public XlaOpKernel {
auto dy_reduced =
XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0));
- xla::ComputationDataHandle gradients = builder->Add(
+ xla::XlaOp gradients = builder->Add(
builder->Mul(in_image, dy_reduced),
builder->Mul(in_grads,
builder->Pow(norm, builder->ConstantR0(-beta_))));
diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
index 886baf8115243a22b7255a3961c914d4cf6c2ed5..6949b296f4b9afe4a0c9152c763a9ad233b9f595 100644
--- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
@@ -66,8 +66,8 @@ class MatMulOp : public XlaOpKernel {
a_shape.DebugString(), ", In[1]: ",
b_shape.DebugString()));
- xla::ComputationDataHandle a = ctx->Input(0);
- xla::ComputationDataHandle b = ctx->Input(1);
+ xla::XlaOp a = ctx->Input(0);
+ xla::XlaOp b = ctx->Input(1);
if (is_sparse_) {
if (a_type_ == DT_BFLOAT16) {
a = ctx->builder()->ConvertElementType(a, xla::F32);
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
index faa415a97b053b4b11d015fefcd430210b98118a..fbd5dc0fdad4483aadbe9bc263cc1f7a034cee09 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
@@ -44,10 +44,10 @@ class MatrixBandPartOp : public XlaOpKernel {
errors::InvalidArgument("num_upper must be scalar, got shape ",
num_upper_in_shape.DebugString()));
- xla::ComputationBuilder* builder = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle num_lower = context->Input(1);
- xla::ComputationDataHandle num_upper = context->Input(2);
+ xla::XlaBuilder* builder = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp num_lower = context->Input(1);
+ xla::XlaOp num_upper = context->Input(2);
DataType input_type = context->input_type(0);
DataType index_type = context->input_type(1);
@@ -58,10 +58,10 @@ class MatrixBandPartOp : public XlaOpKernel {
// Compute 'offset', which is how many diagonals we are above/below the
// diagonal.
- xla::ComputationDataHandle iota_m;
+ xla::XlaOp iota_m;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m));
- xla::ComputationDataHandle iota_n;
+ xla::XlaOp iota_n;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));
auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m,
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
index b2940bdcff75a087c914fdad0cb2426276e41aff..db53f6fef8d6bf901c8281f50791ca6766c46efd 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
@@ -54,16 +54,16 @@ class MatrixSetDiagOp : public XlaOpKernel {
input_shape.DebugString(),
" and diagonal shape: ", diag_shape.DebugString()));
- xla::ComputationBuilder* builder = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle diag = context->Input(1);
+ xla::XlaBuilder* builder = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp diag = context->Input(1);
auto zero = XlaHelpers::Zero(builder, context->input_type(0));
// Create an indicator tensor that is true only on the diagonal.
- xla::ComputationDataHandle iota_m;
+ xla::XlaOp iota_m;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m));
- xla::ComputationDataHandle iota_n;
+ xla::XlaOp iota_n;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n));
auto indicator = builder->Eq(iota_m,
builder->Broadcast(iota_n, {m}),
diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
index 05a36a031ad73be289604da1b7e56203ff12fbf5..7e9de3ef9b245c113cc143128fe58e7e017a361c 100644
--- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
@@ -25,10 +25,11 @@ class MirrorPadOp : public XlaOpKernel {
public:
explicit MirrorPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
- xla::StatusOr DoMirrorPad(
- const xla::ComputationDataHandle& t, const xla::Shape& original_shape,
- const xla::Literal& pad_literal, xla::ComputationBuilder* b) {
- xla::ComputationDataHandle accum = t;
+ xla::StatusOr DoMirrorPad(const xla::XlaOp& t,
+ const xla::Shape& original_shape,
+ const xla::Literal& pad_literal,
+ xla::XlaBuilder* b) {
+ xla::XlaOp accum = t;
for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
--dimno) {
auto t_rev = b->Rev(accum, {dimno});
@@ -76,12 +77,12 @@ class MirrorPadOp : public XlaOpKernel {
OP_REQUIRES_OK(
ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
auto in0 = ctx->Input(0);
- xla::StatusOr> in0_shape = b->GetShape(in0);
+ xla::StatusOr in0_shape = b->GetShape(in0);
OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status());
- xla::StatusOr accum_status =
- DoMirrorPad(in0, *in0_shape.ValueOrDie(), pad_literal, b);
+ xla::StatusOr accum_status =
+ DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, b);
OP_REQUIRES_OK(ctx, accum_status.status());
diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
index 9f7c9913802d311895479b914b66553e135aa426..cac2eea96eeed723b2a63bc9193070cad04b005d 100644
--- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
@@ -62,7 +62,7 @@ class OneHotOp : public XlaOpKernel {
ctx, depth >= 0,
errors::InvalidArgument("depth must be non-negative, got: ", depth));
- xla::ComputationDataHandle one_hot;
+ xla::XlaOp one_hot;
OP_REQUIRES_OK(
ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0),
indices_shape, ctx->Input(0), ctx->Input(2),
diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
index a4318e29d2532faf1f0cc6bb9418d29c2df20cd4..aecaabb6dcf46bdd6ae3da929448d6370acb989b 100644
--- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
@@ -43,7 +43,7 @@ class PackOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- std::vector values;
+ std::vector values;
std::vector shapes;
OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes));
const int num = values.size();
@@ -69,7 +69,7 @@ class PackOp : public XlaOpKernel {
-expanded_num_dims, ", ",
expanded_num_dims, ")"));
- std::vector reshaped_inputs(num);
+ std::vector reshaped_inputs(num);
TensorShape child_shape(shapes[0]);
child_shape.InsertDim(axis, 1);
diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
index 791351637aee61c5fdd911dd8a48959990514395..7c95475e7b1f02183e44f73f116a4aeb25f05c09 100644
--- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
@@ -70,7 +70,7 @@ class PadOp : public XlaOpKernel {
}
// PadV2 added a "constant_values" input that indicates the pad value.
- xla::ComputationDataHandle constant_values;
+ xla::XlaOp constant_values;
if (ctx->num_inputs() == 3) {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)),
errors::InvalidArgument("constant_values must be a scalar."));
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index 5f635dd1bc6122cfcac8163baafd95b13f157715..f8e7b48a0fd94835964aea033ad33523150067b4 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -66,15 +66,15 @@ class PoolingOp : public XlaOpKernel {
int num_dims() const { return num_spatial_dims_ + 2; }
// Method that builds an initial value to use in reductions.
- virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) = 0;
+ virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0;
// The reduction operation to apply to each window.
- virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx) = 0;
+ virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0;
// A post-processing operation to apply on the outputs of the ReduceWindow.
- virtual xla::ComputationDataHandle PostProcessOutput(
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
- DataType dtype, const TensorShape& input_shape) = 0;
+ virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
+ const xla::XlaOp& output, DataType dtype,
+ const TensorShape& input_shape) = 0;
void Compile(XlaOpKernelContext* ctx) override {
std::vector ksize = ksize_;
@@ -110,7 +110,7 @@ class PoolingOp : public XlaOpKernel {
" operator must have ", num_dims(),
" dimensions"));
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
auto input =
XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_);
auto reduce = ctx->builder()->ReduceWindow(
@@ -135,17 +135,17 @@ class MaxPoolOp : public PoolingOp {
: PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
/*reduction_type=*/ctx->input_type(0)) {}
- xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override {
+ xla::XlaOp InitValue(xla::XlaBuilder* b) override {
return XlaHelpers::MinValue(b, reduction_type_);
}
- const xla::Computation* Reduction(XlaOpKernelContext* ctx) override {
+ const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
return ctx->GetOrCreateMax(reduction_type_);
}
- xla::ComputationDataHandle PostProcessOutput(
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
- DataType dtype, const TensorShape& input_shape) override {
+ xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
+ const xla::XlaOp& output, DataType dtype,
+ const TensorShape& input_shape) override {
return output;
}
};
@@ -176,9 +176,9 @@ REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
// Common computation shared between AvgPool and AvgPoolGrad. Divide each
// element of an image by the count of elements that contributed to that
// element during pooling.
-static xla::ComputationDataHandle AvgPoolDivideByCount(
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
- DataType dtype, const TensorShape& input_shape, xla::Padding padding,
+static xla::XlaOp AvgPoolDivideByCount(
+ XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype,
+ const TensorShape& input_shape, xla::Padding padding,
const std::vector& ksize, const std::vector& stride,
int num_spatial_dims, TensorFormat data_format) {
if (padding == xla::Padding::kValid) {
@@ -234,17 +234,17 @@ class AvgPoolOp : public PoolingOp {
/*reduction_type=*/
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
- xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override {
+ xla::XlaOp InitValue(xla::XlaBuilder* b) override {
return XlaHelpers::Zero(b, reduction_type_);
}
- const xla::Computation* Reduction(XlaOpKernelContext* ctx) override {
+ const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
return ctx->GetOrCreateAdd(reduction_type_);
}
- xla::ComputationDataHandle PostProcessOutput(
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
- DataType dtype, const TensorShape& input_shape) override {
+ xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
+ const xla::XlaOp& output, DataType dtype,
+ const TensorShape& input_shape) override {
return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_,
ksize_, stride_, num_spatial_dims_,
data_format_);
@@ -344,11 +344,10 @@ class MaxPoolGradOp : public XlaOpKernel {
xla::PrimitiveType element_type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
- xla::ComputationDataHandle init_value =
- XlaHelpers::Zero(ctx->builder(), input_type(2));
+ xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
auto select = CreateScalarGeComputation(element_type, ctx->builder());
auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
- xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter(
+ xla::XlaOp gradients = ctx->builder()->SelectAndScatter(
input, select, ksize_, stride_, xla_padding, out_backprop, init_value,
scatter);
@@ -462,7 +461,7 @@ class AvgPoolGradOp : public XlaOpKernel {
// The input gradients are computed by a convolution of the output gradients
// and the filter, with some appropriate padding. See the comment at the top
// of conv_grad_ops.h for details.
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
auto out_backprop = ctx->Input(1);
auto dtype = input_type(1);
xla::Padding xla_padding =
diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
index 4171e076ff6d9dd4f809454377620324d1fe5ae4..661cd5923e1023eaf89a6bc4f56fcc362c8bcfb6 100644
--- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
@@ -35,7 +35,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
// Comments taken from semantics description at
@@ -46,8 +46,8 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
// m = max(abs(input_min), abs(input_max)) if range_given is true,
// m = max(abs(min_elem(input)),
// abs(max_elem(input))) otherwise.
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle input_min, input_max;
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp input_min, input_max;
if (range_given_) {
double input_min_value, input_max_value;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value));
@@ -55,14 +55,14 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value);
input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value);
} else {
- const xla::Computation* fmax = ctx->GetOrCreateMax(data_type);
- const xla::Computation* fmin = ctx->GetOrCreateMin(data_type);
+ const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type);
+ const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type);
input_min =
b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin);
input_max =
b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax);
}
- xla::ComputationDataHandle m = b->Max(b->Abs(input_min), b->Abs(input_max));
+ xla::XlaOp m = b->Max(b->Abs(input_min), b->Abs(input_max));
// Next, we choose our fixed-point quantization buckets, [min_fixed,
// max_fixed]. If signed_input is true, this is
@@ -85,7 +85,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
// From this we compute our scaling factor, s:
//
// s = (max_fixed - min_fixed) / (2 * m).
- xla::ComputationDataHandle s =
+ xla::XlaOp s =
b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed),
b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m));
@@ -93,7 +93,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
// e is transformed into e':
//
// e' = (e * s).round_to_nearest() / s.
- xla::ComputationDataHandle result = b->Div(b->Round(b->Mul(input, s)), s);
+ xla::XlaOp result = b->Div(b->Round(b->Mul(input, s)), s);
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index c0994c434bca5174eaee7b9e63e10432d9c2ed8d..5f5bd586376ab368e443671ac8a5de23a5fd604b 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -41,9 +41,9 @@ class RandomUniformOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle result = b->RngUniform(
- XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp result = b->RngUniform(XlaHelpers::Zero(b, dtype),
+ XlaHelpers::One(b, dtype), xla_shape);
ctx->SetOutput(0, result);
}
@@ -100,11 +100,11 @@ class RandomStandardNormalOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
// Normal distribution with a mean of 0 and a standard deviation of 1:
- xla::ComputationDataHandle result = b->RngNormal(
- XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape);
+ xla::XlaOp result = b->RngNormal(XlaHelpers::Zero(b, dtype),
+ XlaHelpers::One(b, dtype), xla_shape);
ctx->SetOutput(0, result);
}
@@ -130,19 +130,18 @@ class TruncatedNormalOp : public XlaOpKernel {
xla::Shape xla_element_shape =
xla::ShapeUtil::MakeShape(xla_shape.element_type(), {});
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype);
- xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype);
- xla::ComputationDataHandle candidate =
- b->RngNormal(mean, stddev, xla_shape);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp mean = XlaHelpers::Zero(b, dtype);
+ xla::XlaOp stddev = XlaHelpers::One(b, dtype);
+ xla::XlaOp candidate = b->RngNormal(mean, stddev, xla_shape);
- auto two_sd = [dtype](bool negate, xla::ComputationBuilder* b) {
+ auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) {
return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0);
};
- auto out_of_range_mask = [two_sd](xla::ComputationDataHandle candidate,
- xla::ComputationBuilder* b) {
- xla::ComputationDataHandle too_large = b->Gt(candidate, two_sd(false, b));
- xla::ComputationDataHandle too_small = b->Lt(candidate, two_sd(true, b));
+ auto out_of_range_mask = [two_sd](xla::XlaOp candidate,
+ xla::XlaBuilder* b) {
+ xla::XlaOp too_large = b->Gt(candidate, two_sd(false, b));
+ xla::XlaOp too_small = b->Lt(candidate, two_sd(true, b));
return b->Or(too_large, too_small);
};
@@ -152,35 +151,32 @@ class TruncatedNormalOp : public XlaOpKernel {
// out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd
// candidate = select(out_of_range_mask, rng_normal(), candidate)
// }
- std::unique_ptr test_builder =
+ std::unique_ptr test_builder =
b->CreateSubBuilder("truncated_normal_test");
{
auto* b = test_builder.get();
- xla::ComputationDataHandle candidate =
- b->Parameter(0, xla_shape, "candidate");
- xla::ComputationDataHandle oor_mask = out_of_range_mask(candidate, b);
+ xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate");
+ out_of_range_mask(candidate, b);
OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status());
}
- std::unique_ptr body_builder =
+ std::unique_ptr body_builder =
b->CreateSubBuilder("truncated_normal_body");
{
auto* b = body_builder.get();
- xla::ComputationDataHandle candidate =
- b->Parameter(0, xla_shape, "candidate");
- xla::ComputationDataHandle to_resample = out_of_range_mask(candidate, b);
- xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype);
- xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype);
+ xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate");
+ xla::XlaOp to_resample = out_of_range_mask(candidate, b);
+ xla::XlaOp mean = XlaHelpers::Zero(b, dtype);
+ xla::XlaOp stddev = XlaHelpers::One(b, dtype);
b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate);
}
- xla::StatusOr test_computation = test_builder->Build();
+ xla::StatusOr test_computation = test_builder->Build();
OP_REQUIRES_OK(ctx, test_computation.status());
- xla::StatusOr body_computation = body_builder->Build();
+ xla::StatusOr body_computation = body_builder->Build();
OP_REQUIRES_OK(ctx, body_computation.status());
- xla::ComputationDataHandle result =
- b->While(test_computation.ValueOrDie(), body_computation.ValueOrDie(),
- candidate);
+ xla::XlaOp result = b->While(test_computation.ValueOrDie(),
+ body_computation.ValueOrDie(), candidate);
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index cb144bea9e429b7c8bcc3d07f688ed6a254c3be0..08894489ac77bbbe4ddb067c06a6d031a537697d 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -65,7 +64,7 @@ class ReduceWindowOp : public XlaOpKernel {
"rank (",
padding_high_.size(), " vs. ", rank, ")"));
- xla::ComputationBuilder* builder = context->builder();
+ xla::XlaBuilder* builder = context->builder();
// Build the reducer function.
XlaCompiler::Argument reducer_arg;
@@ -95,15 +94,15 @@ class ReduceWindowOp : public XlaOpKernel {
xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
// Wraps the reducer in a computation that unpacks the output tuple.
- xla::Computation wrapper;
+ xla::XlaComputation wrapper;
{
- std::unique_ptr cb =
+ std::unique_ptr cb =
builder->CreateSubBuilder("wrapper");
auto x = cb->Parameter(0, scalar_shape, "x");
auto y = cb->Parameter(1, scalar_shape, "y");
auto outputs = cb->Call(*reducer.computation, {x, y});
cb->GetTupleElement(outputs, 0);
- xla::StatusOr result = cb->Build();
+ xla::StatusOr result = cb->Build();
OP_REQUIRES_OK(context, result.status());
wrapper = std::move(result.ValueOrDie());
}
@@ -113,7 +112,7 @@ class ReduceWindowOp : public XlaOpKernel {
padding[i] = {padding_low_[i], padding_high_[i]};
}
- xla::ComputationDataHandle output = builder->ReduceWindowWithGeneralPadding(
+ xla::XlaOp output = builder->ReduceWindowWithGeneralPadding(
context->Input(0), context->Input(1), wrapper, window_dimensions_,
window_strides_, padding);
context->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index 812d258cd1677e18ef49952044126c76a2f55b19..0f425637795e9633a8e36f921000ee2f5e25813a 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -30,13 +30,11 @@ class SumOp : public XlaReductionOp {
explicit SumOp(OpKernelConstruction* ctx)
: XlaReductionOp(ctx,
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return XlaHelpers::Zero(builder, reduction_type_);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Add(scalar_lhs, scalar_rhs);
}
};
@@ -49,14 +47,12 @@ class ProdOp : public XlaReductionOp {
: XlaReductionOp(ctx,
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return XlaHelpers::One(builder, reduction_type_);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Mul(scalar_lhs, scalar_rhs);
}
};
@@ -69,14 +65,12 @@ class MinOp : public XlaReductionOp {
explicit MinOp(OpKernelConstruction* ctx)
: XlaReductionOp(ctx, ctx->input_type(0)) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return XlaHelpers::MaxValue(builder, reduction_type_);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Min(scalar_lhs, scalar_rhs);
}
};
@@ -88,14 +82,12 @@ class MaxOp : public XlaReductionOp {
explicit MaxOp(OpKernelConstruction* ctx)
: XlaReductionOp(ctx, ctx->input_type(0)) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return XlaHelpers::MinValue(builder, reduction_type_);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Max(scalar_lhs, scalar_rhs);
}
};
@@ -108,20 +100,17 @@ class MeanOp : public XlaReductionOp {
: XlaReductionOp(ctx,
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return XlaHelpers::Zero(builder, reduction_type_);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Add(scalar_lhs, scalar_rhs);
}
- xla::ComputationDataHandle BuildFinalizer(
- xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& reduce_output,
- int64 num_elements_reduced) override {
+ xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder,
+ const xla::XlaOp& reduce_output,
+ int64 num_elements_reduced) override {
auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0),
num_elements_reduced);
return builder->Div(reduce_output, divisor);
@@ -136,14 +125,12 @@ class AllOp : public XlaReductionOp {
explicit AllOp(OpKernelConstruction* ctx)
: XlaReductionOp(ctx, ctx->input_type(0)) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return builder->ConstantR0(true);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->And(scalar_lhs, scalar_rhs);
}
};
@@ -155,14 +142,12 @@ class AnyOp : public XlaReductionOp {
explicit AnyOp(OpKernelConstruction* ctx)
: XlaReductionOp(ctx, ctx->input_type(0)) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return builder->ConstantR0(false);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Or(scalar_lhs, scalar_rhs);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
index f3181f0dadc2d3f45abb145e009e2663c10490f0..2ecfb854a1c8625524d4f1199af3927edd204926 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
@@ -19,7 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
@@ -28,35 +28,33 @@ namespace tensorflow {
// to override: description is a textual description of the mapped
// function; InitialValue constructs the base case for the reduction;
// BuildReducer adds the implementation of the reduction lambda to a
-// xla::ComputationBuilder and BuildFinalizer adds the
+// xla::XlaBuilder and BuildFinalizer adds the
// implementation of the finalizer lambda (if there is one) to a
-// xla::ComputationBuilder.
+// xla::XlaBuilder.
class XlaReductionOp : public XlaOpKernel {
public:
XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type);
~XlaReductionOp() override {}
// Return the base case for the reduction.
- virtual xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) = 0;
+ virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
// Implement the (scalar,scalar)->scalar lambda that should be
// applied to each pair of elements to be reduced. The desired
// computation should be added to 'builder' and
// '(scalar_lhs,scalar_rhs)' are the function's inputs.
- virtual void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) = 0;
+ virtual void BuildReducer(xla::XlaBuilder* builder,
+ const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) = 0;
// Applies a transformation to the output of the reduction. The desired
// computation should be added to 'builder'. Argument 'reduce_output' is the
// output of the reduction. 'num_elements_reduced' is the number of elements
// that contributed to the reduction. Returns the transformed reduction
// output, Defaults to returning 'reduce_output' unchanged.
- virtual xla::ComputationDataHandle BuildFinalizer(
- xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& reduce_output,
- int64 num_elements_reduced);
+ virtual xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder,
+ const xla::XlaOp& reduce_output,
+ int64 num_elements_reduced);
void Compile(XlaOpKernelContext* ctx) override;
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index 64fe765ae9a945c58ea60bc157b1520c83b0d8e7..4fd5bfd03999a7f8b7bb081cc4b03aa1434d4c3d 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -35,10 +35,9 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
// Unless BuildFinalizer is overridden the reduction has no
// finalizer.
-xla::ComputationDataHandle XlaReductionOp::BuildFinalizer(
- xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& reduce_output,
- int64 num_elements_reduced) {
+xla::XlaOp XlaReductionOp::BuildFinalizer(xla::XlaBuilder* builder,
+ const xla::XlaOp& reduce_output,
+ int64 num_elements_reduced) {
return reduce_output;
}
@@ -96,9 +95,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
string desc = ctx->op_kernel().name();
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
// Construct the builder for the reduction lambda.
- xla::ComputationBuilder r(b->client(), strings::StrCat(desc, "-reduction"));
+ xla::XlaBuilder r(strings::StrCat(desc, "-reduction"));
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
@@ -110,7 +109,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y");
// Call virtual method to build the reduction lambda.
BuildReducer(&r, rx, ry);
- xla::Computation reduction_computation = r.Build().ConsumeValueOrDie();
+ xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie();
auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes);
auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0));
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index 12a35529992e6160566046dd28f9321c88afec91..ba7d484d53d7258edaa5bc42fa116cf16e94835b 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
@@ -32,7 +32,7 @@ class ReluOp : public XlaOpKernel {
explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
auto zero = XlaHelpers::Zero(builder, input_type(0));
ctx->SetOutput(0, builder->Max(zero, ctx->Input(0)));
}
@@ -43,7 +43,7 @@ class Relu6Op : public XlaOpKernel {
explicit Relu6Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Clamp the scalar input between 0 and 6.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
auto zero = XlaHelpers::Zero(builder, input_type(0));
auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6);
ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six));
@@ -56,7 +56,7 @@ class ReluGradOp : public XlaOpKernel {
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
// otherwise return 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const TensorShape shape = ctx->InputShape(0);
const auto zero =
b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
@@ -71,7 +71,7 @@ class Relu6GradOp : public XlaOpKernel {
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
// otherwise return 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const TensorShape shape = ctx->InputShape(0);
const auto zero =
b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index c283e3b02c2676785952e3e17bffa671b0dabc1e..70547290eaed169599764a5d66185dde85345863 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -45,7 +45,7 @@ class RetvalOp : public XlaOpKernel {
// compilation.
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input));
} else {
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
const TensorShape input_shape = ctx->InputShape(0);
auto is_constant = ctx->builder()->IsConstant(input);
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
index e51d386926763ecbb5a943dfb6f872e78901dc69..2872a3c4d49d0d269aa3d216887a5c32cd51f1c3 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -48,7 +48,7 @@ class ReverseOp : public XlaOpKernel {
ctx->SetOutput(0, ctx->Input(0));
return;
}
- // ComputationBuilder::Rev() requires concrete values for dimensions arg.
+ // XlaBuilder::Rev() requires concrete values for dimensions arg.
xla::Literal lax;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax));
std::vector revdims(x_shape.dims());
@@ -90,7 +90,7 @@ class ReverseV2Op : public XlaOpKernel {
ctx->SetOutput(0, ctx->Input(0));
return;
}
- // ComputationBuilder::Rev() requires concrete values for dimensions arg.
+ // XlaBuilder::Rev() requires concrete values for dimensions arg.
std::vector axes;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes));
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
index 6bc5d3adb091cd238974c5b69b7a2f8fe639cc68..0ed4c4707df71cf5f56ccfe0af506916f04bcdb5 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
@@ -54,7 +54,7 @@ class ReverseSequenceOp : public XlaOpKernel {
"), ", "(", seq_lens_shape.num_elements(),
" vs. ", input_shape.dim_size(batch_dim_)));
- xla::ComputationBuilder* builder = context->builder();
+ xla::XlaBuilder* builder = context->builder();
const auto input = context->Input(0);
const auto seq_lens = context->Input(1);
@@ -155,7 +155,7 @@ class ReverseSequenceOp : public XlaOpKernel {
auto output = builder->GetTupleElement(loop_output, 2);
// Mask out elements after the sequence length.
- xla::ComputationDataHandle iota;
+ xla::XlaOp iota;
OP_REQUIRES_OK(
context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota));
std::vector dims(input_shape.dims(), 1);
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index 4cfa28a0ce3d7d1f24196ef6ef2775f840b2bcf1..1819fb543317eed15b2fe0518d74aba5c564697d 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -74,7 +74,7 @@ class ScanOp : public XlaOpKernel {
return;
}
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
std::vector window_strides(input_shape.dims(), 1);
std::vector window_dims(input_shape.dims(), 1);
@@ -91,8 +91,8 @@ class ScanOp : public XlaOpKernel {
std::swap(padding[axis].first, padding[axis].second);
}
- xla::ComputationDataHandle init;
- const xla::Computation* reducer;
+ xla::XlaOp init;
+ const xla::XlaComputation* reducer;
if (sum_) {
init = XlaHelpers::Zero(builder, dtype);
reducer = ctx->GetOrCreateAdd(dtype);
diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
index 8433a29c4e203cac726ee6bf7f67a863447326ed..f2c63b4f9083ad3c7dd7cf318dc22def1e99fa9f 100644
--- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
@@ -102,7 +102,7 @@ class ScatterNdOp : public XlaOpKernel {
OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape,
updates_shape));
- xla::ComputationBuilder* builder = context->builder();
+ xla::XlaBuilder* builder = context->builder();
auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
buffer_shape.dim_sizes());
auto indices = context->Input(0);
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
index 498342a98881df0c6ff50007eacc1d5ef6196b57..664078ca16c6d5d4b57c4a8c661ad0848f30dd7d 100644
--- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -62,16 +62,16 @@ class UnsortedSegmentSum : public XlaOpKernel {
d, " differs ", data_shape.dim_size(d), " vs. ",
indices_shape.dim_size(d)));
}
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
TensorShape buffer_shape = data_shape;
buffer_shape.RemoveDimRange(0, indices_shape.dims());
buffer_shape.InsertDim(0, num_segments);
auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_),
buffer_shape.dim_sizes());
- auto combiner =
- [](xla::ComputationDataHandle a, xla::ComputationDataHandle b,
- xla::ComputationBuilder* builder) { return builder->Add(a, b); };
+ auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) {
+ return builder->Add(a, b);
+ };
auto result = XlaScatter(buffer, /*updates=*/data, indices,
/*indices_are_vectors=*/false, combiner, builder);
diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc
index 8081d3c41c436324c21858124121fecfac71cefa..f9f48164d63492b057d4950abfc2ca6153e44870 100644
--- a/tensorflow/compiler/tf2xla/kernels/select_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc
@@ -40,7 +40,7 @@ class SelectOp : public XlaOpKernel {
"'then' and 'else' must have the same size. but received: ",
then_shape.DebugString(), " vs. ", else_shape.DebugString()));
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
auto cond_handle = ctx->Input(0);
auto then_handle = ctx->Input(1);
diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
index d079b89861817a5639ac72b5ee49d76cb4506ae8..9ce01d0d44509bbcbea18afdb4210a675834bb6d 100644
--- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
index 463788b8b461c370a8e7ab4d79a94fc0143b8b45..bbf5ee8b12186a582666121b1df5d8b7d881863e 100644
--- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -43,8 +43,8 @@ class SoftmaxOp : public XlaOpKernel {
const DataType type = input_type(0);
auto logits = ctx->Input(0);
- xla::ComputationBuilder* const b = ctx->builder();
- const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
+ xla::XlaBuilder* const b = ctx->builder();
+ const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
// Find the max in each batch, resulting in a tensor of shape [batch]
auto logits_max =
@@ -76,16 +76,15 @@ class SoftmaxOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp);
REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp);
-std::pair
-CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type,
- const xla::ComputationDataHandle& logits,
- const xla::ComputationDataHandle& labels) {
- const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
+std::pair CrossEntropyWithLogits(
+ XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits,
+ const xla::XlaOp& labels) {
+ const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
const int kBatchDim = 0;
const int kClassDim = 1;
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
// Find the max in each batch, resulting in a tensor of shape [batch]
auto logits_max =
b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
@@ -123,7 +122,7 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type,
// backprop: prob - labels, where
// prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
// (where the division broadcasts along the batch dimension)
- xla::ComputationDataHandle backprop =
+ xla::XlaOp backprop =
b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
return {loss, backprop};
}
@@ -150,7 +149,7 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel {
auto logits = ctx->Input(0);
auto labels = ctx->Input(1);
- xla::ComputationDataHandle loss, backprop;
+ xla::XlaOp loss, backprop;
std::tie(loss, backprop) =
CrossEntropyWithLogits(ctx, type, logits, labels);
ctx->SetOutput(0, loss);
@@ -191,10 +190,10 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
DataType logits_type = input_type(0);
DataType indices_type = input_type(1);
- xla::ComputationDataHandle indices = ctx->Input(1);
+ xla::XlaOp indices = ctx->Input(1);
- xla::ComputationBuilder* builder = ctx->builder();
- xla::ComputationDataHandle labels;
+ xla::XlaBuilder* builder = ctx->builder();
+ xla::XlaOp labels;
OP_REQUIRES_OK(ctx,
XlaHelpers::OneHot(
builder, depth, /*axis=*/1, input_type(1), labels_shape,
@@ -207,7 +206,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
// Builds a vector of {batch_size} that is 0 if the index is in range, or
// NaN otherwise; then add that vector to the labels to force out-of-range
// values to NaNs.
- xla::ComputationDataHandle nan_or_zero = builder->Select(
+ xla::XlaOp nan_or_zero = builder->Select(
builder->And(
builder->Le(XlaHelpers::Zero(builder, indices_type), indices),
builder->Lt(indices, XlaHelpers::IntegerLiteral(
@@ -218,7 +217,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
{batch_size}));
labels = builder->Add(labels, nan_or_zero, {0});
- xla::ComputationDataHandle loss, backprop;
+ xla::XlaOp loss, backprop;
std::tie(loss, backprop) =
CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels);
ctx->SetOutput(0, loss);
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
index 01b46e160d1f1f10a43faf7ca35afb42dfde6e33..ec077924b5b5af4a573c86c8d9aeb8623bd7f801 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -20,9 +20,8 @@ limitations under the License.
namespace tensorflow {
namespace {
-void SpaceToBatch(XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input, DataType input_dtype,
- const TensorShape& input_tensor_shape,
+void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
+ DataType input_dtype, const TensorShape& input_tensor_shape,
gtl::ArraySlice block_shape,
const xla::Literal& paddings) {
const int input_rank = input_tensor_shape.dims();
@@ -46,7 +45,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
", 2] instead of ",
xla::ShapeUtil::HumanString(paddings.shape())));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
// 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the
// input according to `paddings` to produce `padded` of shape `padded_shape`.
@@ -73,7 +72,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
errors::InvalidArgument(
"The product of the block dimensions must be positive"));
- xla::ComputationDataHandle padded =
+ xla::XlaOp padded =
b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
// 2. Reshape `padded` to `reshaped_padded` of shape:
@@ -101,8 +100,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_padded_shape.begin() + 1 + 2 * block_rank);
- xla::ComputationDataHandle reshaped_padded =
- b->Reshape(padded, reshaped_padded_shape);
+ xla::XlaOp reshaped_padded = b->Reshape(padded, reshaped_padded_shape);
// 3. Permute dimensions of `reshaped_padded` to produce
// `permuted_reshaped_padded` of shape:
@@ -121,7 +119,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
permutation[block_rank] = 0;
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
- xla::ComputationDataHandle permuted_reshaped_padded =
+ xla::XlaOp permuted_reshaped_padded =
b->Transpose(reshaped_padded, permutation);
// 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
@@ -142,8 +140,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
std::copy(remainder_shape.begin(), remainder_shape.end(),
output_shape.begin() + 1 + block_rank);
- xla::ComputationDataHandle output =
- b->Reshape(permuted_reshaped_padded, output_shape);
+ xla::XlaOp output = b->Reshape(permuted_reshaped_padded, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
index 806fda632cde64c1b37ae3b9199028d6b6b0a215..4c5886ee2a0f63d609f79fc690f457d93e284e3e 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
@@ -50,8 +50,8 @@ class SpaceToDepthOp : public XlaOpKernel {
const gtl::InlinedVector input_shape =
input_tensor_shape.dim_sizes();
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_);
@@ -135,7 +135,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[1] / block_size_, block_size_,
// input_shape[2] / block_size_, block_size_,
// depth]
- xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce
// `permuted_reshaped` of shape:
@@ -145,8 +145,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[2] / block_size_,
// block_size_, block_size_,
// depth]
- xla::ComputationDataHandle permuted_reshaped =
- b->Transpose(reshaped, transpose_order);
+ xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order);
// 3. Reshape `permuted_reshaped` to flatten `block_shape` into the
// batch dimension, producing an output tensor of shape:
@@ -156,8 +155,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[2] / block_size_,
// block_size_ * block_size_ * depth]
//
- xla::ComputationDataHandle output =
- b->Reshape(permuted_reshaped, output_shape);
+ xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index 43c15e753805352875034dfd2c70a2a1ed9a4114..8958b2e7701e62d802e37a895c14b662ecf9786a 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -124,7 +124,7 @@ class SplitVOp : public XlaOpKernel {
input_shape.dims(), "), but got ",
split_dim_orig));
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
OP_REQUIRES(ctx, input_shape.dims() > 0,
errors::InvalidArgument("Can't split a 0 dimensional input"));
diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
index 1a78c7ab9be701d3d02285ed21604f0f856b3f1f..0fb05a2be7b1034d6c2e864643b69647d622ede7 100644
--- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
@@ -38,13 +38,13 @@ limitations under the License.
namespace tensorflow {
namespace {
-Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource,
+Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource,
TensorShape* stack_shape) {
auto shape_or_status = builder->GetShape(resource->value());
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
- xla::Shape shape = *shape_or_status.ValueOrDie();
+ xla::Shape shape = shape_or_status.ValueOrDie();
TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape));
return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
stack_shape);
@@ -60,9 +60,8 @@ Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource,
//
// TODO(phawkins): consider changing the API of the stack operators to
// allow an optional element shape at stack construction time.
-Status MaybeInitializeStack(xla::ComputationBuilder* builder,
- XlaResource* resource, DataType dtype,
- const TensorShape& elem_shape) {
+Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource,
+ DataType dtype, const TensorShape& elem_shape) {
if (resource->type() != dtype) {
return errors::InvalidArgument(
"Stack dtype is ", DataTypeString(resource->type()),
@@ -75,8 +74,6 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder,
if (!resource->initialized()) {
// Stack has not been initialized.
- xla::ComputationDataHandle zero =
- XlaHelpers::Zero(builder, resource->type());
TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
@@ -111,7 +108,7 @@ class StackOp : public XlaOpKernel {
// We defer initializing the Stack resource until we see the first push.
// Otherwise we do not know the shape of the stack elements.
- xla::ComputationDataHandle value;
+ xla::XlaOp value;
XlaContext& xc = XlaContext::Get(ctx);
XlaResource* resource;
string name = strings::StrCat("Stack: ", stack_name_);
@@ -138,7 +135,7 @@ class StackPushOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
TensorShape elem_shape = ctx->InputShape(1);
XlaResource* resource;
@@ -147,9 +144,9 @@ class StackPushOp : public XlaOpKernel {
// Initializes the Stack, if the element shape was not already known.
OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = b->GetTupleElement(resource->value(), 0);
- xla::ComputationDataHandle index = b->GetTupleElement(resource->value(), 1);
- xla::ComputationDataHandle value = ctx->Input(1);
+ xla::XlaOp ta = b->GetTupleElement(resource->value(), 0);
+ xla::XlaOp index = b->GetTupleElement(resource->value(), 1);
+ xla::XlaOp value = ctx->Input(1);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices =
@@ -184,7 +181,7 @@ class StackPopOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
@@ -199,9 +196,9 @@ class StackPopOp : public XlaOpKernel {
TensorShape stack_shape;
OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape));
- xla::ComputationDataHandle state = resource->value();
- xla::ComputationDataHandle ta = b->GetTupleElement(state, 0);
- xla::ComputationDataHandle index = b->GetTupleElement(state, 1);
+ xla::XlaOp state = resource->value();
+ xla::XlaOp ta = b->GetTupleElement(state, 0);
+ xla::XlaOp index = b->GetTupleElement(state, 1);
index = b->Sub(index, b->ConstantR0(1));
OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index})));
@@ -216,8 +213,7 @@ class StackPopOp : public XlaOpKernel {
// TODO(phawkins): We don't check the index is in bounds --- there is no
// error mechanism in XLA.
- xla::ComputationDataHandle read =
- b->DynamicSlice(ta, start_indices, slice_shape);
+ xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector value_shape(slice_shape.begin() + 1, slice_shape.end());
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index 5bb773d97fc5ce90dabceeefd5c29d916597f5ff..a99d4ddc7c4956f7144512a9bdf6f4c2eb0f944f 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -30,9 +30,8 @@ namespace tensorflow {
namespace {
// Rotates a 32-bit integer 'v' left by 'distance' bits.
-xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& v,
- int distance) {
+xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v,
+ int distance) {
return builder->Or(
builder->ShiftLeft(v, builder->ConstantR0(distance)),
builder->ShiftRightLogical(v, builder->ConstantR0(32 - distance)));
@@ -40,25 +39,24 @@ xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder,
// TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than
// building XOR out of other bitwise operators.
-xla::ComputationDataHandle BitwiseXor(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& y) {
+xla::XlaOp BitwiseXor(xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const xla::XlaOp& y) {
return builder->Or(builder->And(x, builder->Not(y)),
builder->And(builder->Not(x), y));
}
-using ThreeFry2x32State = std::array;
+using ThreeFry2x32State = std::array;
// Implements the ThreeFry counter-based PRNG algorithm.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
-ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder,
+ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder,
ThreeFry2x32State input, ThreeFry2x32State key) {
// Rotation distances specified by the Threefry2x32 algorithm.
constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24};
ThreeFry2x32State x;
- std::array ks;
+ std::array ks;
// 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
ks[2] = builder->ConstantR0(0x1BD11BDA);
for (int i = 0; i < 2; ++i) {
@@ -121,10 +119,9 @@ ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder,
// Returns a tensor of 'shape' random values uniformly distributed in the range
// [minval, maxval)
-xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& seed,
- const TensorShape& shape,
- double minval, double maxval) {
+xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed,
+ const TensorShape& shape, double minval,
+ double maxval) {
// Split the seed into two 32-bit scalars to form a key.
auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {});
auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {});
@@ -178,9 +175,8 @@ xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder,
// p = sum_{i=1}^n gq[i]*w^i
// }
// return p*x
-xla::ComputationDataHandle ErfInvF32(xla::ComputationBuilder* b,
- const xla::ComputationDataHandle& x,
- const TensorShape& shape) {
+xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x,
+ const TensorShape& shape) {
constexpr int kDegree = 9;
constexpr std::array w_less_than_5_constants = {
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
@@ -220,7 +216,7 @@ class StatelessRandomUniformOp : public XlaOpKernel {
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
@@ -229,7 +225,7 @@ class StatelessRandomUniformOp : public XlaOpKernel {
OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
errors::InvalidArgument("seed must have shape [2], not ",
seed_shape.DebugString()));
- xla::ComputationDataHandle seed = ctx->Input(1);
+ xla::XlaOp seed = ctx->Input(1);
ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0));
}
@@ -257,9 +253,10 @@ class StatelessRandomNormalOp : public XlaOpKernel {
OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
errors::InvalidArgument("seed must have shape [2], not ",
seed_shape.DebugString()));
- xla::ComputationDataHandle seed = ctx->Input(1);
- xla::ComputationBuilder* builder = ctx->builder();
- auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0);
+ xla::XlaOp seed = ctx->Input(1);
+ xla::XlaBuilder* builder = ctx->builder();
+ auto uniform =
+ RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0);
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)),
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 6204aa4e27000fddec7f5b82b2198d37956f6aba..55254c746e5ebaf6b468c24ab59b968bf0d6260b 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -90,7 +90,7 @@ class StridedSliceOp : public XlaOpKernel {
}
}
- xla::ComputationDataHandle slice = ctx->Input(0);
+ xla::XlaOp slice = ctx->Input(0);
if (!dimensions_to_reverse.empty()) {
slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
}
@@ -168,7 +168,7 @@ class StridedSliceGradOp : public XlaOpKernel {
auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));
- xla::ComputationDataHandle grad = ctx->Input(4);
+ xla::XlaOp grad = ctx->Input(4);
// Undo any new/shrink axes.
grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes());
@@ -255,7 +255,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
&strides_tensor));
TensorShape lhs_shape;
- xla::ComputationDataHandle lhs;
+ xla::XlaOp lhs;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
const TensorShape rhs_shape = ctx->InputShape(4);
@@ -284,7 +284,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
" does not match r-value shape ", rhs_shape.DebugString(),
". Automatic broadcasting not yet implemented."));
- xla::ComputationDataHandle rhs = ctx->Input(4);
+ xla::XlaOp rhs = ctx->Input(4);
gtl::InlinedVector dimensions_to_reverse;
gtl::InlinedVector slice_begin, slice_dims;
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 000b50af6bd86b7268c016865fb0856c16053ece..9adee78a1fd1fb9a12afae83197425c328b5fe7e 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -47,7 +47,7 @@ namespace {
// the TensorArray with elements of `elem_shape`. For both initialized and
// uninitialized TensorArrays, checks that the tensor has a type compatible with
// 'dtype' and shape compatible with 'elem_shape'.
-Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
+Status MaybeInitializeTensorArray(xla::XlaBuilder* builder,
XlaResource* resource, DataType dtype,
const TensorShape& elem_shape) {
if (resource->kind() != XlaResource::kTensorArray) {
@@ -64,9 +64,6 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
<< resource->name() << " size " << resource->tensor_array_size();
if (!resource->initialized()) {
- xla::ComputationDataHandle zero =
- XlaHelpers::Zero(builder, resource->type());
-
TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
@@ -77,7 +74,7 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
}
TensorShape shape;
TF_RETURN_IF_ERROR(
- XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+ XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
TensorShape ta_shape;
ta_shape.AddDim(resource->tensor_array_size());
@@ -114,23 +111,21 @@ Status CheckTensorArrayIsInitialized(const string& op_name,
}
Status GetTensorArrayShape(const XlaResource* resource,
- xla::ComputationBuilder* builder,
- TensorShape* shape) {
+ xla::XlaBuilder* builder, TensorShape* shape) {
*shape = resource->shape();
shape->InsertDim(0, resource->tensor_array_size());
return Status::OK();
}
-// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the
+// Like XlaBuilder::DynamicUpdateSlice, but adds 'update' to the
// relevant slice of 'operand'.
-xla::ComputationDataHandle DynamicAddSlice(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand,
- const xla::ComputationDataHandle& update,
- const gtl::ArraySlice& update_dims,
- const xla::ComputationDataHandle& start_indices) {
- xla::ComputationDataHandle current =
+xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand,
+ const xla::XlaOp& update,
+ const gtl::ArraySlice& update_dims,
+ const xla::XlaOp& start_indices) {
+ xla::XlaOp current =
builder->DynamicSlice(operand, start_indices, update_dims);
- xla::ComputationDataHandle sum = builder->Add(current, update);
+ xla::XlaOp sum = builder->Add(current, update);
return builder->DynamicUpdateSlice(operand, sum, start_indices);
}
@@ -155,18 +150,18 @@ class TensorArrayOp : public XlaOpKernel {
OP_REQUIRES(ctx, size >= 0,
errors::InvalidArgument("TensorArray size must be >= 0"));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
// Initializes the TensorArray value if we know the element shape.
// Otherwise, defer initialization to the first write.
- xla::ComputationDataHandle value;
+ xla::XlaOp value;
TensorShape shape;
if (element_shape_.IsFullyDefined()) {
CHECK(element_shape_.AsTensorShape(&shape));
TensorShape ta_shape;
ta_shape.AddDim(size);
ta_shape.AppendShape(shape);
- xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_);
+ xla::XlaOp zero = XlaHelpers::Zero(b, dtype_);
value = b->Broadcast(zero, ta_shape.dim_sizes());
}
@@ -202,7 +197,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
TensorShape elem_shape = ctx->InputShape(2);
@@ -213,10 +208,10 @@ class TensorArrayWriteOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = resource->value();
- xla::ComputationDataHandle index = ctx->Input(1);
- xla::ComputationDataHandle value = ctx->Input(2);
- xla::ComputationDataHandle flow = ctx->Input(3);
+ xla::XlaOp ta = resource->value();
+ xla::XlaOp index = ctx->Input(1);
+ xla::XlaOp value = ctx->Input(2);
+ xla::XlaOp flow = ctx->Input(3);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices =
@@ -227,7 +222,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
slice_shape.InsertDim(0, 1LL);
auto update = b->Reshape(value, slice_shape.dim_sizes());
- xla::ComputationDataHandle written =
+ xla::XlaOp written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
OP_REQUIRES_OK(ctx, resource->SetValue(written));
@@ -249,7 +244,7 @@ class TensorArrayReadOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
@@ -259,8 +254,8 @@ class TensorArrayReadOp : public XlaOpKernel {
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
- xla::ComputationDataHandle ta = resource->value();
- xla::ComputationDataHandle index = ctx->Input(1);
+ xla::XlaOp ta = resource->value();
+ xla::XlaOp index = ctx->Input(1);
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
@@ -270,8 +265,7 @@ class TensorArrayReadOp : public XlaOpKernel {
auto slice_shape = ta_shape.dim_sizes();
slice_shape[0] = 1LL;
- xla::ComputationDataHandle read =
- b->DynamicSlice(ta, start_indices, slice_shape);
+ xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector value_shape(slice_shape.begin() + 1, slice_shape.end());
@@ -293,7 +287,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
@@ -309,7 +303,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
auto indices = ctx->Input(1);
DataType index_type = ctx->input_type(1);
- xla::ComputationDataHandle ta = resource->value();
+ xla::XlaOp ta = resource->value();
// Look for the case where the gather takes a simple slice from the
// tensor array (0, 1, 2, 3, 4, ..., N)
@@ -337,7 +331,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
}
}
- xla::ComputationDataHandle gather;
+ xla::XlaOp gather;
OP_REQUIRES_OK(
ctx,
XlaGather(ta, ta_shape, indices, indices_shape, /*axis=*/0,
@@ -360,7 +354,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const TensorShape value_shape = ctx->InputShape(2);
@@ -375,11 +369,11 @@ class TensorArrayScatterOp : public XlaOpKernel {
OP_REQUIRES(ctx, indices_shape.dims() >= 1,
errors::InvalidArgument("indices must be rank 1"));
const int num_indices = indices_shape.dim_size(0);
- const xla::ComputationDataHandle indices = ctx->Input(1);
+ const xla::XlaOp indices = ctx->Input(1);
- xla::ComputationDataHandle ta = resource->value();
- const xla::ComputationDataHandle value = ctx->Input(2);
- const xla::ComputationDataHandle flow = ctx->Input(3);
+ xla::XlaOp ta = resource->value();
+ const xla::XlaOp value = ctx->Input(2);
+ const xla::XlaOp flow = ctx->Input(3);
// Look for the case where the scatter is for each sub-tensor in order. The
// tensor array implementation allows for this to be a straight addition.
@@ -443,7 +437,7 @@ class TensorArrayConcatOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
@@ -453,7 +447,7 @@ class TensorArrayConcatOp : public XlaOpKernel {
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
- xla::ComputationDataHandle ta = resource->value();
+ xla::XlaOp ta = resource->value();
auto ta_dims = ta_shape.dim_sizes();
std::vector shape(ta_dims.begin() + 1, ta_dims.end());
@@ -503,12 +497,12 @@ class TensorArraySplitOp : public XlaOpKernel {
TensorShape elem_shape = value_shape;
elem_shape.set_dim(0, length);
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
OP_REQUIRES_OK(ctx,
MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = resource->value();
+ xla::XlaOp ta = resource->value();
TensorShape ta_shape;
ta_shape.AddDim(resource->tensor_array_size());
@@ -520,8 +514,8 @@ class TensorArraySplitOp : public XlaOpKernel {
"TensorArray's size is not equal to the size of lengths (",
lengths.size(), " vs. ", resource->tensor_array_size(), ")"));
- const xla::ComputationDataHandle value = ctx->Input(1);
- const xla::ComputationDataHandle flow = ctx->Input(3);
+ const xla::XlaOp value = ctx->Input(1);
+ const xla::XlaOp flow = ctx->Input(3);
OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(),
errors::InvalidArgument("mismatched element count ",
@@ -569,7 +563,7 @@ class TensorArrayGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index 9aefcd4fc7f94a1dba1c56273c55d0b98fbbfaf2..e91075196bd8414939888e22b5483ad637487af6 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -112,7 +112,7 @@ class TileOp : public XlaOpKernel {
flattened.push_back(i);
flattened.push_back(i + output_shape.size());
}
- xla::ComputationDataHandle output =
+ xla::XlaOp output =
ctx->builder()->Reshape(broadcasted, flattened, output_shape);
ctx->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index f750f7003be288461f5f10455e58932d1b4e4524..34caefa050c0d58f5f7bad557286b6ed64b996ad 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
@@ -30,8 +30,8 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle handle;
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaOp handle;
+ xla::XlaBuilder* b = ctx->builder();
DataType type = ctx->input_type(1);
TensorShape var_shape;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));
@@ -63,12 +63,12 @@ class ResourceApplyMomentum : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
DataType type = ctx->input_type(2);
TensorShape var_shape, accum_shape;
- xla::ComputationDataHandle var, accum;
+ xla::XlaOp var, accum;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
@@ -93,9 +93,9 @@ class ResourceApplyMomentum : public XlaOpKernel {
errors::InvalidArgument("momentum is not a scalar: ",
momentum_shape.DebugString()));
- xla::ComputationDataHandle lr = ctx->Input(2);
- xla::ComputationDataHandle grad = ctx->Input(3);
- xla::ComputationDataHandle momentum = ctx->Input(4);
+ xla::XlaOp lr = ctx->Input(2);
+ xla::XlaOp grad = ctx->Input(3);
+ xla::XlaOp momentum = ctx->Input(4);
accum = b->Add(b->Mul(accum, momentum), grad);
if (use_nesterov_) {
@@ -121,12 +121,12 @@ class ResourceApplyAdagrad : public XlaOpKernel {
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
DataType type = ctx->input_type(2);
TensorShape var_shape, accum_shape;
- xla::ComputationDataHandle var, accum;
+ xla::XlaOp var, accum;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
@@ -146,8 +146,8 @@ class ResourceApplyAdagrad : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle lr = ctx->Input(2);
- xla::ComputationDataHandle grad = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(2);
+ xla::XlaOp grad = ctx->Input(3);
accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)));
var = b->Sub(
@@ -168,7 +168,7 @@ class ResourceApplyAdam : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
TensorShape var_shape, m_shape, v_shape;
- xla::ComputationDataHandle var, m, v;
+ xla::XlaOp var, m, v;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));
@@ -213,25 +213,25 @@ class ResourceApplyAdam : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle beta1_power = ctx->Input(3);
- xla::ComputationDataHandle beta2_power = ctx->Input(4);
- xla::ComputationDataHandle lr = ctx->Input(5);
- xla::ComputationDataHandle beta1 = ctx->Input(6);
- xla::ComputationDataHandle beta2 = ctx->Input(7);
- xla::ComputationDataHandle epsilon = ctx->Input(8);
- xla::ComputationDataHandle grad = ctx->Input(9);
+ xla::XlaOp beta1_power = ctx->Input(3);
+ xla::XlaOp beta2_power = ctx->Input(4);
+ xla::XlaOp lr = ctx->Input(5);
+ xla::XlaOp beta1 = ctx->Input(6);
+ xla::XlaOp beta2 = ctx->Input(7);
+ xla::XlaOp epsilon = ctx->Input(8);
+ xla::XlaOp grad = ctx->Input(9);
// alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
// v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t
// variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon)
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
- xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
- xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
+ xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
+ xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
- xla::ComputationDataHandle alpha =
+ xla::XlaOp alpha =
b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)),
b->Sub(one, beta1_power));
m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1)));
@@ -255,12 +255,12 @@ class ResourceApplyRMSProp : public XlaOpKernel {
explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
DataType type = ctx->input_type(3);
TensorShape var_shape, ms_shape, mom_shape;
- xla::ComputationDataHandle var, ms, mom;
+ xla::XlaOp var, ms, mom;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
@@ -297,11 +297,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle lr = ctx->Input(3);
- xla::ComputationDataHandle rho = ctx->Input(4);
- xla::ComputationDataHandle momentum = ctx->Input(5);
- xla::ComputationDataHandle epsilon = ctx->Input(6);
- xla::ComputationDataHandle grad = ctx->Input(7);
+ xla::XlaOp lr = ctx->Input(3);
+ xla::XlaOp rho = ctx->Input(4);
+ xla::XlaOp momentum = ctx->Input(5);
+ xla::XlaOp epsilon = ctx->Input(6);
+ xla::XlaOp grad = ctx->Input(7);
// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
@@ -320,16 +320,16 @@ class ResourceApplyRMSProp : public XlaOpKernel {
// ms <- grad**2 (1 - rho) + ms * rho
//
// Which is the equation listed above.
- xla::ComputationDataHandle new_ms = b->Add(
+ xla::XlaOp new_ms = b->Add(
ms,
b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms),
b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho)));
- xla::ComputationDataHandle new_mom =
+ xla::XlaOp new_mom =
b->Add(b->Mul(mom, momentum),
b->Mul(b->Mul(grad, lr),
b->Pow(b->Add(new_ms, epsilon),
XlaHelpers::FloatLiteral(b, type, -0.5))));
- xla::ComputationDataHandle new_var = b->Sub(var, new_mom);
+ xla::XlaOp new_var = b->Sub(var, new_mom);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms));
@@ -341,10 +341,10 @@ REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
bool has_l2_shrinkage) {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
TensorShape var_shape, accum_shape, linear_shape;
- xla::ComputationDataHandle var, accum, linear;
+ xla::XlaOp var, accum, linear;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));
@@ -399,12 +399,12 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
errors::InvalidArgument("lr_power is not a scalar: ",
lr_power_shape.DebugString()));
- xla::ComputationDataHandle grad = ctx->Input(3);
- xla::ComputationDataHandle lr = ctx->Input(4);
- xla::ComputationDataHandle l1 = ctx->Input(5);
- xla::ComputationDataHandle l2 = ctx->Input(6);
- xla::ComputationDataHandle l2_shrinkage;
- xla::ComputationDataHandle lr_power;
+ xla::XlaOp grad = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(4);
+ xla::XlaOp l1 = ctx->Input(5);
+ xla::XlaOp l2 = ctx->Input(6);
+ xla::XlaOp l2_shrinkage;
+ xla::XlaOp lr_power;
if (has_l2_shrinkage) {
l2_shrinkage = ctx->Input(7);
lr_power = ctx->Input(8);
@@ -421,26 +421,23 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
// var = (linear_clipped - linear) / quadratic
// accum = new_accum
- xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
- xla::ComputationDataHandle grad_to_use;
+ xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
+ xla::XlaOp grad_to_use;
if (has_l2_shrinkage) {
grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var)));
} else {
grad_to_use = grad;
}
- xla::ComputationDataHandle new_accum =
- b->Add(accum, b->Pow(grad_to_use, two));
- xla::ComputationDataHandle new_accum_lr_pow =
- b->Pow(new_accum, b->Neg(lr_power));
- xla::ComputationDataHandle accum_lr_pow = b->Pow(accum, b->Neg(lr_power));
+ xla::XlaOp new_accum = b->Add(accum, b->Pow(grad_to_use, two));
+ xla::XlaOp new_accum_lr_pow = b->Pow(new_accum, b->Neg(lr_power));
+ xla::XlaOp accum_lr_pow = b->Pow(accum, b->Neg(lr_power));
linear = b->Add(
linear,
b->Sub(grad_to_use,
b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var)));
- xla::ComputationDataHandle linear_clipped = b->Clamp(b->Neg(l1), linear, l1);
- xla::ComputationDataHandle quadratic =
- b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2));
+ xla::XlaOp linear_clipped = b->Clamp(b->Neg(l1), linear, l1);
+ xla::XlaOp quadratic = b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2));
var = b->Div(b->Sub(linear_clipped, linear), quadratic);
accum = new_accum;
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index 7cb47f908d4ff43f455f1e77c53cd3cc956579ee..a4f50f52ebe8b1ed7df862996d64e135ea1d0ac5 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
@@ -33,9 +33,9 @@ namespace {
public: \
explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \
void Compile(XlaOpKernelContext* ctx) { \
- xla::ComputationBuilder* b = ctx->builder(); \
- xla::ComputationDataHandle x = ctx->Input(0); \
- xla::ComputationDataHandle y = COMPUTATION; \
+ xla::XlaBuilder* b = ctx->builder(); \
+ xla::XlaOp x = ctx->Input(0); \
+ xla::XlaOp y = COMPUTATION; \
ctx->SetOutput(0, y); \
} \
}; \
@@ -124,9 +124,8 @@ XLAJIT_MAKE_UNARY(Neg, b->Neg(x));
// Implements Banker's rounding: numbers that are equidistant between two
// integers are rounded towards even.
-static xla::ComputationDataHandle Round(xla::ComputationBuilder* b,
- DataType dtype,
- const xla::ComputationDataHandle& x) {
+static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype,
+ const xla::XlaOp& x) {
auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
@@ -148,9 +147,8 @@ XLAJIT_MAKE_UNARY(Rsqrt,
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
-static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
- DataType dtype,
- const xla::ComputationDataHandle& x) {
+static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype,
+ const xla::XlaOp& x) {
auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x))));
}
@@ -162,20 +160,18 @@ XLAJIT_MAKE_UNARY(Sinh,
b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))),
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
-static xla::ComputationDataHandle Softplus(
- xla::ComputationBuilder* b, DataType dtype,
- const xla::ComputationDataHandle& features) {
- xla::ComputationDataHandle threshold =
- b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)),
- XlaHelpers::FloatLiteral(b, dtype, 2.0));
+static xla::XlaOp Softplus(xla::XlaBuilder* b, DataType dtype,
+ const xla::XlaOp& features) {
+ xla::XlaOp threshold = b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)),
+ XlaHelpers::FloatLiteral(b, dtype, 2.0));
// Value above which exp(x) may overflow, but softplus(x) == x
// is within machine epsilon.
- xla::ComputationDataHandle too_large = b->Gt(features, b->Neg(threshold));
+ xla::XlaOp too_large = b->Gt(features, b->Neg(threshold));
// Value below which exp(x) may underflow, but softplus(x) == exp(x)
// is within machine epsilon.
- xla::ComputationDataHandle too_small = b->Lt(features, threshold);
- xla::ComputationDataHandle features_exp = b->Exp(features);
- xla::ComputationDataHandle output = b->Select(
+ xla::XlaOp too_small = b->Lt(features, threshold);
+ xla::XlaOp features_exp = b->Exp(features);
+ xla::XlaOp output = b->Select(
too_large, features,
b->Select(too_small, features_exp,
b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype)))));
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index 71173f5aead47702f0ed9e95b827a6fefd9b7efd..6109db8e89e5ee67e0635d26e258bfe7cb70a15d 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
@@ -48,7 +48,7 @@ class ReadVariableOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle handle;
+ xla::XlaOp handle;
OP_REQUIRES_OK(
ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
ctx->SetOutput(0, handle);
@@ -74,7 +74,7 @@ class AssignAddVariableOp : public XlaOpKernel {
explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
DataType type = ctx->input_type(1);
- xla::ComputationDataHandle handle;
+ xla::XlaOp handle;
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
handle = ctx->builder()->Add(handle, ctx->Input(1));
@@ -90,7 +90,7 @@ class AssignSubVariableOp : public XlaOpKernel {
explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
DataType type = ctx->input_type(1);
- xla::ComputationDataHandle handle;
+ xla::XlaOp handle;
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
handle = ctx->builder()->Sub(handle, ctx->Input(1));
@@ -105,19 +105,19 @@ class ResourceGatherOp : public XlaOpKernel {
public:
explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
DataType type = ctx->expected_output_dtype(0);
TensorShape resource_shape;
- xla::ComputationDataHandle resource_handle;
+ xla::XlaOp resource_handle;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
&resource_handle));
auto indices = ctx->Input(1);
auto indices_shape = ctx->InputShape(1);
DataType index_type = ctx->input_type(1);
- xla::ComputationDataHandle gather;
+ xla::XlaOp gather;
OP_REQUIRES_OK(
ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape,
/*axis=*/0, /*indices_are_nd=*/false, type, index_type,
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 0ff1b65ae9179d506e453f98097cd88083eb2be7..5467c5d9946846ff9f14ce9c5aac9e2be4b9d6ab 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -101,7 +101,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
ctx, MakeXlaCompilerArgumentsFromInputs(
ctx, &arguments, &has_uninitialized_vars, &has_tensor_arrays));
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
XlaCompiler* compiler = ctx->compiler();
VLOG(1) << "Compiling body";
@@ -234,7 +234,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
xla::ShapeUtil::HumanString(cond.xla_output_shape)));
int num_inputs = body.input_mapping.size();
- std::vector inputs(num_inputs);
+ std::vector inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = body.input_mapping[i];
if (ctx->input_type(input_num) == DT_RESOURCE) {
@@ -246,24 +246,24 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
}
}
- xla::ComputationDataHandle init = builder->Tuple(inputs);
+ xla::XlaOp init = builder->Tuple(inputs);
VLOG(1) << "Building while loop";
// Wraps the condition in a computation that unpacks the output tuple.
- xla::Computation cond_wrapper;
+ xla::XlaComputation cond_wrapper;
{
- std::unique_ptr cb =
+ std::unique_ptr cb =
builder->CreateSubBuilder("cond_wrapper");
auto inputs = cb->Parameter(0, cond_input_shape, "inputs");
auto outputs = cb->Call(*cond.computation, {inputs});
cb->GetTupleElement(outputs, 0);
- xla::StatusOr result = cb->Build();
+ xla::StatusOr result = cb->Build();
OP_REQUIRES_OK(ctx, result.status());
cond_wrapper = std::move(result.ValueOrDie());
}
- xla::ComputationDataHandle while_result =
+ xla::XlaOp while_result =
builder->While(cond_wrapper, *body.computation, init);
// Sets non-variable outputs.
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 12fdfb605d667bf2cc96e79e84954b89229a7340..04ad3694a0c0df9d43c706d428c3b8715e5ff8ca 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -25,8 +25,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -44,8 +44,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -62,9 +62,9 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -82,8 +82,8 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -101,9 +101,9 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -122,8 +122,8 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -161,8 +161,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index 798f0fa78055e800038e8bf41b4f410b670be7dd..526694d5a0c7124e1696f34b516f3b202462bc19 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -25,24 +25,22 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr BatchDot(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
- xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
- bool conjugate_x, bool conjugate_y) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr x_shape,
- builder->GetShape(x));
- TF_ASSIGN_OR_RETURN(std::unique_ptr y_shape,
- builder->GetShape(y));
+xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
+ xla::XlaOp y, bool transpose_x,
+ bool transpose_y, bool conjugate_x,
+ bool conjugate_y) {
+ TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
+ TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y));
// Check that both tensors have the same number of dimensions. There must be
// at least two (the batch dimensions can be empty).
- if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) {
+ if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) {
return errors::InvalidArgument(
"Arguments to BatchedDot have different ranks: ",
- xla::ShapeUtil::HumanString(*x_shape), " vs. ",
- xla::ShapeUtil::HumanString(*y_shape));
+ xla::ShapeUtil::HumanString(x_shape), " vs. ",
+ xla::ShapeUtil::HumanString(y_shape));
}
- const int ndims = xla::ShapeUtil::Rank(*x_shape);
+ const int ndims = xla::ShapeUtil::Rank(x_shape);
if (ndims < 2) {
return errors::InvalidArgument(
"Arguments to BatchedDot must have rank >= 2: ", ndims);
@@ -52,46 +50,46 @@ xla::StatusOr BatchDot(
// valid.
std::vector batch_dimension_numbers;
for (int i = 0; i < ndims - 2; ++i) {
- if (x_shape->dimensions(i) != y_shape->dimensions(i)) {
+ if (x_shape.dimensions(i) != y_shape.dimensions(i)) {
return errors::InvalidArgument(
"Dimension ", i, " of inputs to BatchedDot must be equal: ",
- xla::ShapeUtil::HumanString(*x_shape), " vs ",
- xla::ShapeUtil::HumanString(*y_shape));
+ xla::ShapeUtil::HumanString(x_shape), " vs ",
+ xla::ShapeUtil::HumanString(y_shape));
}
batch_dimension_numbers.push_back(i);
}
int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
- if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) {
+ if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) {
return errors::InvalidArgument(
"Dimensions ", x_inner_dim, " and ", y_inner_dim,
" of arguments to BatchedDot must be equal: ",
- xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x,
- " vs. ", xla::ShapeUtil::HumanString(*y_shape),
+ xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x,
+ " vs. ", xla::ShapeUtil::HumanString(y_shape),
" transpose: ", transpose_y);
}
// Check for zero lhs/rhs dim size.
- if (xla::ShapeUtil::HasZeroElements(*x_shape) ||
- xla::ShapeUtil::HasZeroElements(*y_shape)) {
+ if (xla::ShapeUtil::HasZeroElements(x_shape) ||
+ xla::ShapeUtil::HasZeroElements(y_shape)) {
std::vector dimensions(batch_dimension_numbers.size());
for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
- dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]);
+ dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]);
}
int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
- dimensions.push_back(x_shape->dimensions(x_outer_dim));
- dimensions.push_back(y_shape->dimensions(y_outer_dim));
+ dimensions.push_back(x_shape.dimensions(x_outer_dim));
+ dimensions.push_back(y_shape.dimensions(y_outer_dim));
return builder->Broadcast(
- builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())),
+ builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())),
dimensions);
}
- if (x_shape->element_type() == xla::C64 && conjugate_x) {
+ if (x_shape.element_type() == xla::C64 && conjugate_x) {
x = builder->Conj(x);
}
- if (y_shape->element_type() == xla::C64 && conjugate_y) {
+ if (y_shape.element_type() == xla::C64 && conjugate_y) {
y = builder->Conj(y);
}
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index b230e885f10f45a78cdd6e455da3ba55ce589b96..1acc72033b05e73b0f5f88907df20cde5cfffbf0 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
namespace tensorflow {
@@ -43,10 +43,10 @@ namespace tensorflow {
// It is computed as:
//
// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
-xla::StatusOr BatchDot(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
- xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
- bool conjugate_x = false, bool conjugate_y = false);
+xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
+ xla::XlaOp y, bool transpose_x,
+ bool transpose_y, bool conjugate_x = false,
+ bool conjugate_y = false);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 203365e2ab07e0da1abfac5452a8ec41a4ddf406..83e73827862ca26a1a51bed72ab87768854c1e71 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -47,23 +47,21 @@ namespace {
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
// l[..., j, j]
// return l
-xla::StatusOr CholeskyUnblocked(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape,
- builder->GetShape(a));
- const int n_dims = xla::ShapeUtil::Rank(*a_shape);
- const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
- gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape->dimensions()),
+xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder,
+ const xla::XlaOp& a) {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int n_dims = xla::ShapeUtil::Rank(a_shape);
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()),
/*pos=*/0,
/*len=*/n_dims - 2);
- xla::ComputationDataHandle l = Zeros(builder, *a_shape);
+ xla::XlaOp l = Zeros(builder, a_shape);
// Construct the for loop body to iterate over rows.
- auto body_fn = [&](xla::ComputationDataHandle i,
- gtl::ArraySlice loop_vars,
- xla::ComputationBuilder* body_builder)
- -> xla::StatusOr> {
+ auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars,
+ xla::XlaBuilder* body_builder)
+ -> xla::StatusOr> {
xla::Shape col_shape;
xla::Shape row_shape;
for (int64 d : major_dims) {
@@ -72,12 +70,12 @@ xla::StatusOr CholeskyUnblocked(
}
row_shape.add_dimensions(1);
row_shape.add_dimensions(n);
- row_shape.set_element_type(a_shape->element_type());
+ row_shape.set_element_type(a_shape.element_type());
auto mask_zeros_row = Zeros(body_builder, row_shape);
col_shape.add_dimensions(n);
col_shape.add_dimensions(1);
- col_shape.set_element_type(a_shape->element_type());
+ col_shape.set_element_type(a_shape.element_type());
auto mask_zeros_col = Zeros(body_builder, col_shape);
std::vector mask_vector(n);
@@ -101,7 +99,7 @@ xla::StatusOr CholeskyUnblocked(
TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a,
{i, i}, {1, 1}));
// np.dot(row, np.swapaxes(row, -1, -2))
- xla::ComputationDataHandle diag_dot;
+ xla::XlaOp diag_dot;
TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row,
/*transpose_x=*/false,
/*transpose_y=*/true));
@@ -109,7 +107,7 @@ xla::StatusOr CholeskyUnblocked(
// np.swapaxes(row, -1, -2)))
auto l_ii = body_builder->Pow(
body_builder->Sub(a_ii, diag_dot),
- FloatLiteral(body_builder, a_shape->element_type(), 0.5));
+ FloatLiteral(body_builder, a_shape.element_type(), 0.5));
// a[..., i+1:, i]
auto ip1 = body_builder->Add(i, body_builder->ConstantR0(1));
@@ -140,7 +138,7 @@ xla::StatusOr CholeskyUnblocked(
TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
body_builder, body_l, l_ii, {i, i}));
- return std::vector{body_a, body_l};
+ return std::vector{body_a, body_l};
};
TF_ASSIGN_OR_RETURN(
@@ -152,22 +150,20 @@ xla::StatusOr CholeskyUnblocked(
} // namespace
-xla::StatusOr Cholesky(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
- int64 block_size) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape,
- builder->GetShape(a));
- const int ndims = xla::ShapeUtil::Rank(*a_shape);
+xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
+ int64 block_size) {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int ndims = xla::ShapeUtil::Rank(a_shape);
if (ndims < 2) {
return errors::InvalidArgument(
"Arguments to Cholesky must have rank >= 2: ", ndims);
}
- const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
- if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) {
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) {
return errors::InvalidArgument(
"Arguments to Cholesky must be square matrices: ",
- xla::ShapeUtil::HumanString(*a_shape));
+ xla::ShapeUtil::HumanString(a_shape));
}
if (block_size < 1) {
@@ -179,7 +175,7 @@ xla::StatusOr Cholesky(
// Algorithm 1 from
// Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only
// execution." Proceedings of General Purpose GPUs. ACM, 2017.
- xla::ComputationDataHandle l = Zeros(builder, *a_shape);
+ xla::XlaOp l = Zeros(builder, a_shape);
for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i);
if (i > 0) {
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index 17da8d8b22d107701ce768ac945c1404df6d47e8..20fca7969ece2729a44933fd3ef3f87230ab6cad 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
namespace tensorflow {
@@ -30,9 +30,8 @@ namespace tensorflow {
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(znado): handle the complex Hermitian case
-xla::StatusOr Cholesky(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
- int64 block_size = 256);
+xla::StatusOr