diff --git a/RELEASE.md b/RELEASE.md index 763ef3b279dde209ed387534032deae40a33a9e4..bdc23795e55800a885386ab8d63b032fa4979149 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,9 @@ +# Release 1.10.1 +## Bug Fixes and Other Changes + +* `tf.keras`: + * Fixing keras on Cloud TPUs. No new binaries will be built for Windows. + # Release 1.10.0 ## Major Features And Improvements diff --git a/configure.py b/configure.py index 361bd4764dc5c1900be7378f51c00aedf6f2ce41..52a513779e601482d673297ed08e43133c5ad3c7 100644 --- a/configure.py +++ b/configure.py @@ -852,7 +852,7 @@ def set_tf_cuda_version(environ_cp): # Reset and retry print('Invalid path to CUDA %s toolkit. %s cannot be found' % - (tf_cuda_version, cuda_toolkit_path_full)) + (tf_cuda_version, cuda_toolkit_paths_full)) environ_cp['TF_CUDA_VERSION'] = '' environ_cp['CUDA_TOOLKIT_PATH'] = '' diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 173bbea596a4276559f5cd67824e5cc75313985c..79811ceae57e0bddeb2a6f32bad7003e14e23422 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index c046bd66cda593e4feaf02f9e8068d4b59cf3e19..c195c9e01ca920c7234499b6e1d5e9cbf24056f3 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index a2c5a42c11361779de61b515e0f08dcc45e609b9..f68f8a3e90a971b5e4a024feaf26ba498afc48da 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/strings/base64.h" diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index a085e1d6e2de5ad63d11eb8979ae64c26b91366f..0717e7dd4b358d6c212070374bcc3fd2f91ed0ab 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -150,7 +150,7 @@ class Input { Initializer(const std::initializer_list& v, const TensorShape& shape) { typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), shape); - if (t.NumElements() != v.size()) { + if (t.NumElements() != static_cast(v.size())) { status = errors::InvalidArgument( "Cannot construct a tensor with ", t.NumElements(), " from an initializer list with ", v.size(), " elements"); diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index bd270045e3800705a53798220d80022f09faabe2..cf5c04ac4bdff73b76a365c346f7db60ce2d8197 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -20,7 +20,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ #define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ -#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 8d94f5495cdb64fdb8a453c5b591564dd4990dcf..7a0932d44d405de0f2edf072f4760126bff36719 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -231,6 +231,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_profile_printer", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index dd2b151098f2054571ac32b8b506cbc00659588a..7ac90fb8a9c73bdbc149f263d7d229a6514769f8 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -543,7 +544,13 @@ TEST(TFCompileTest, HloProfiling) { string hlo_profile_as_string = xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(), /*clock_rate_ghz=*/1.0); - VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string; + VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string; + + // Strip away identifier details from the profile string to avoid this test + // being a change detector for xla internals. Identifiers such as '%dot.0.7' + // just become '%dot'. + RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1"); + VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string; std::vector hlo_profile_lines = absl::StrSplit(hlo_profile_as_string, '\n'); @@ -551,16 +558,14 @@ TEST(TFCompileTest, HloProfiling) { auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); auto dot_profile_line = HasSubstr( - "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " - "%arg1.0.1)"); + "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); auto add_profile_line = HasSubstr( - "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " - "%arg1.0.1)"); + "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); auto tuple_profile_line = HasSubstr( - "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " - "%dot.0.4, f32[2,2]{1,0} %add.0.6)"); - auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); - auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); + "%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, " + "f32[2,2]{1,0} %add)"); + auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)"); + auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)"); EXPECT_THAT(hlo_profile_lines, IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 1c9d30d7b0865ac9504d31a01b9c56c61e626f77..b95b063348c5cdfdcaed635ba527e9f0bfd6092d 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" @@ -93,8 +92,9 @@ Status Main(const MainFlags& flags) { // Write output files. Env* env = Env::Default(); const std::vector& obj = compile_result.aot->object_file_data(); - TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object, - StringPiece(obj.data(), obj.size()))); + TF_RETURN_IF_ERROR( + WriteStringToFile(env, flags.out_function_object, + absl::string_view(obj.data(), obj.size()))); CodegenOpts codegen_opts; codegen_opts.gen_name_to_index = flags.gen_name_to_index; codegen_opts.gen_program_shape = flags.gen_program_shape; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a989f15a1c8681c6b0118e4d45ec237ca3ae81cb..f4e1bc5e8390107df8ea1a5f8eb6b0193082d3fd 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -265,6 +265,7 @@ cc_library( srcs = ["jit_compilation_pass_registration.cc"], deps = [ ":compilation_passes", + "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu_internal", ], alwayslink = 1, @@ -362,6 +363,7 @@ cc_library( "deadness_analysis.cc", "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", + "encapsulate_xla_computations_pass.cc", "mark_for_compilation_pass.cc", "mark_for_compilation_pass_test_helper.cc", "partially_decluster_pass.cc", @@ -370,6 +372,7 @@ cc_library( "build_xla_launch_ops_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", + "encapsulate_xla_computations_pass.h", "mark_for_compilation_pass.h", "mark_for_compilation_pass_test_helper.h", "partially_decluster_pass.h", @@ -396,6 +399,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -474,6 +478,7 @@ tf_cc_test( size = "small", srcs = [ "encapsulate_subgraphs_pass_test.cc", + "encapsulate_xla_computations_pass_test.cc", "mark_for_compilation_pass_test.cc", "partially_decluster_pass_test.cc", ], @@ -489,7 +494,9 @@ tf_cc_test( "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index ae7a22f4516fc6c87c0c555214eacac71f2ea0d7..e0632ff7e48ccea99d469f62ec9d0a3fe8295024 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" @@ -58,6 +59,22 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; const char* const kXlaHostTransferSequencerAttr = "_xla_host_transfer_sequencer"; +void SortControlInputs(GraphDef* gdef) { + int64 num_nodes = gdef->node_size(); + for (int64 i = 0; i < num_nodes; ++i) { + NodeDef* node = gdef->mutable_node(i); + // Stable sort control inputs and leave the order of data inputs unchanged. + std::stable_sort(node->mutable_input()->begin(), + node->mutable_input()->end(), + [](const string& a, const string& b) { + bool a_is_control = absl::StartsWith(a, "^"); + bool b_is_control = absl::StartsWith(b, "^"); + return (!a_is_control && b_is_control) || + (a_is_control && b_is_control && a < b); + }); + } +} + namespace { bool AreAllParentsGuaranteedConst( diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 926589546fec72048485d30966f31b24e44b1245..90354a801afb26b003e00c4529069fdc61bbca32 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -102,6 +102,12 @@ extern const char* const kXlaNumConstantArgsAttr; // Name of the attribute containing the number of resource variable arguments. extern const char* const kXlaNumResourceArgsAttr; +// Sorts each node's control inputs by their names. This guarantees that for two +// structually equivalent GraphDefs, we get the same traversal ordering on +// node's control input fields. +// TODO(hpucha): Move the utilities to a more appropriate place. +void SortControlInputs(GraphDef* gdef); + class EncapsulateSubgraphsPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..97ef8cd3cb3fba54259fc413e0a3d3e75a89c431 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -0,0 +1,360 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/fingerprint.h" + +namespace tensorflow { + +const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr = + "_xla_compile_id"; + +namespace { + +const char* const kXlaClusterOutput = "XlaClusterOutput"; + +// Checks if a graph node is marked to be a guaranteed constant. +bool is_guaranteed_constant(const Node& n) { + bool guaranteed_constant = false; + if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant) + .ok()) { + return false; + } + return guaranteed_constant; +} + +// Finds the `index` of an _Arg or _Retval node. +Status GetIndexAttr(const Node& n, int num_args, int* index) { + TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index)); + if (*index < 0 || *index >= num_args) { + return errors::InvalidArgument("Invalid ", n.type_string(), " number ", + *index); + } + return Status::OK(); +} + +// Returns the data type of the destination of an edge. +DataType EdgeType(const Edge* edge) { + return edge->dst()->input_type(edge->dst_input()); +} + +// Adds the control inputs of `node` to `*deps`. +void AddControlInputs(const Node& node, gtl::FlatSet* deps) { + for (const Edge* edge : node.in_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->src()); + } + } +} + +// Adds the control outputs of `node` to `*deps`. +void AddControlOutputs(const Node& node, gtl::FlatSet* deps) { + for (const Edge* edge : node.out_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->dst()); + } + } +} + +// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts +// the arguments into the order expected by XlaLaunch computations: +// 1) arguments +// 2) resource variable arguments +// See the documentation of EncapsulateSubgraphsInFunctions for the meaning +// of the arguments. +// +// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed. +Status RewriteSubgraph(const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + const int num_args = input_permutation->size(); + const int num_retvals = output_permutation->size(); + + std::vector args; + std::vector retvals; + args.reserve(num_args); + retvals.reserve(num_retvals); + for (Node* n : graph->nodes()) { + if (n->type_string() == "_Arg") { + // Check if this is a guaranteed constant. + if (is_guaranteed_constant(*n)) { + return errors::InvalidArgument( + "Guaranteed constants are not supported (", n->name(), ")"); + } + args.push_back(n); + } else if (n->type_string() == "_Retval") { + retvals.push_back(n); + } + } + + if (std::find(args.begin(), args.end(), nullptr) != args.end()) { + return errors::InvalidArgument("Missing or non-consecutive arguments"); + } + + // Reorders the arguments. + std::sort(args.begin(), args.end(), [&](Node* a, Node* b) { + // Non-resources appear before resources + bool a_is_resource = (a->output_type(0) == DT_RESOURCE); + bool b_is_resource = (b->output_type(0) == DT_RESOURCE); + // Uses the name as a tiebreaker so the output is deterministic. + StringPiece a_name(a->name()); + StringPiece b_name(b->name()); + return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name); + }); + + // Sorts the retvals by name so the order is deterministic. + std::sort(retvals.begin(), retvals.end(), + [](Node* a, Node* b) { return a->name() < b->name(); }); + + // Computes the permutation to produce the correct argument order, and update + // the argument indices. + int variable_start_index = num_args; + for (int i = 0; i < num_args; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index)); + if (args[i]->output_type(0) == DT_RESOURCE && + variable_start_index == num_args) { + variable_start_index = i; + } + (*input_permutation)[index] = i; + args[i]->AddAttr("index", i); + } + VLOG(4) << "variable_start_index: " << variable_start_index; + + // Computes the permutation to produce the correct retval order, and update + // the argument indices. + for (int i = 0; i < num_retvals; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index)); + (*output_permutation)[index] = i; + retvals[i]->AddAttr("index", i); + } + + AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(), + call_def); + AddNodeAttr("_variable_start_index", variable_start_index, call_def); + + // Uniquify the function name. + GraphDef gdef; + graph->ToGraphDef(&gdef); + + // Before serialization, sort each node's control inputs to achieve + // determinism. Sorting control inputs could help (but not necessarily) create + // a deterministic serialization and fingerprint. Other sources of + // nondeterminism include unstable node ordering. + SortControlInputs(&gdef); + // Fingerprint the function. + // Nondeterminism in serialization would not lead to incorrect results, but + // may cause spurious cache misses. DeterministicSerialization is a + // best-effort deterministic serialization. + string serialized; + TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized)); + uint64 fingerprint = Fingerprint64(serialized); + LOG(INFO) << "Subgraph fingerprint:" << fingerprint; + call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); + return Status::OK(); +} + +} // namespace + +/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate( + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + // Check for undeclared outputs before Encapsulation, so we can give a better + // error message. + // TODO(phawkins): merge this with the encapsulation code to avoid the extra + // O(n) pass over the edges. + for (const Edge* e : (*graph)->edges()) { + if (!e->IsControlEdge() && + e->src()->attrs().Find(kXlaClusterAttr) != nullptr && + e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && + e->dst()->type_string() != kXlaClusterOutput) { + return errors::InvalidArgument( + "Undeclared output of XLA computation. A common cause of this error " + "is variable initializers that depend on the XLA computation. Edge: ", + e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":", + e->dst_input()); + } + } + + auto output = absl::make_unique((*graph)->op_registry()); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, "", **graph, RewriteSubgraph, + /*reuse_existing_functions=*/true, &output, flib_def), + "EncapsulateXlaComputationsPass failed"); + graph->swap(output); + return Status::OK(); +} + +/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( + Graph* graph) { + // Finds all of the XlaLaunch function calls, to avoid mutating the graph + // while iterating. + std::vector launch_nodes; + for (Node* n : graph->nodes()) { + string name; + if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) { + launch_nodes.push_back(n); + } + } + + // Replaces each launch function call together with its neighboring + // XlaClusterOutput nodes with a XlaLaunch node. + for (Node* launch : launch_nodes) { + int variable_start_index; + TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index", + &variable_start_index)); + + std::vector in_edges; + TF_RETURN_IF_ERROR(launch->input_edges(&in_edges)); + + const int num_inputs = in_edges.size(); + const int num_variables = num_inputs - variable_start_index; + const int num_args = variable_start_index; + + VLOG(4) << "Launch node '" << launch->name() << "'" + << " input edges: " << in_edges.size() << " num_args: " << num_args + << " num_variables: " << num_variables; + + std::vector nodes_to_remove = {launch}; + + // Data and control inputs to the new XlaLaunch node. + std::vector> data_inputs(num_inputs); + gtl::FlatSet control_inputs; + DataTypeVector arg_types(num_args); + + AddControlInputs(*launch, &control_inputs); + + for (int i = 0; i < num_args; ++i) { + const Edge* edge = in_edges[i]; + data_inputs[i] = {edge->src(), edge->src_output()}; + arg_types[i] = EdgeType(edge); + } + + // Appends the variable inputs. + for (int i = 0; i < num_variables; ++i) { + int pos = variable_start_index + i; + const Edge* edge = in_edges[pos]; + data_inputs[pos] = {edge->src(), edge->src_output()}; + } + + // Outputs. + const int num_outputs = launch->output_types().size(); + gtl::FlatSet control_outputs; + std::vector>> data_outputs(num_outputs); + DataTypeVector output_types(num_outputs); + + for (const Edge* le : launch->out_edges()) { + if (le->IsControlEdge()) { + control_outputs.insert(le->dst()); + } else { + TF_RET_CHECK(le->src_output() < num_outputs); + Node* output_node = le->dst(); + + TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput) + << le->DebugString(); + nodes_to_remove.push_back(output_node); + + for (const Edge* oe : output_node->out_edges()) { + TF_RET_CHECK(!oe->IsControlEdge()); + data_outputs[le->src_output()].push_back( + {oe->dst(), oe->dst_input()}); + } + output_types[le->src_output()] = output_node->input_type(0); + + AddControlOutputs(*output_node, &control_outputs); + } + } + + NodeDef def; + def.set_name(launch->name()); + + // Target the XLA CPU/GPU backends. + VLOG(2) << "Replacing with XlaLaunch"; + def.set_op("XlaLaunch"); + AddNodeAttr("Tconstants", DataTypeVector{}, &def); + AddNodeAttr("Targs", arg_types, &def); + AddNodeAttr("Nresources", num_variables, &def); + AddNodeAttr("Tresults", output_types, &def); + NameAttrList function; + function.set_name(launch->type_string()); + AddNodeAttr("function", function, &def); + + for (Node* node : nodes_to_remove) { + VLOG(2) << "Deleting node " << node->DebugString(); + // Ensure that we do not attempt to add control edges to nodes that are + // deleted. + control_inputs.erase(node); + control_outputs.erase(node); + graph->RemoveNode(node); + } + + Status status; + Node* xla_launch = graph->AddNode(def, &status); + if (!status.ok()) { + return status; + } + for (int i = 0; i < data_inputs.size(); ++i) { + graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch, + i); + } + for (Node* n : control_inputs) { + graph->AddControlEdge(n, xla_launch); + } + for (int i = 0; i < data_outputs.size(); ++i) { + for (const auto& successor : data_outputs[i]) { + graph->AddEdge(xla_launch, i, successor.first, successor.second); + } + } + for (Node* n : control_outputs) { + graph->AddControlEdge(xla_launch, n); + } + } + return Status::OK(); +} + +Status EncapsulateXlaComputationsPass::Run( + const GraphOptimizationPassOptions& options) { + VLOG(1) << "EncapsulateXlaComputations(): " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def)); + VLOG(1) << "EncapsulateXlaComputations() half-way: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get())); + VLOG(1) << "EncapsulateXlaComputations() finished: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after", + **options.graph, options.flib_def); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..99e9dfd598f29697dd009aa32f5317ed3dc647ae --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +// Rewrites computations generated by the xla.compile() Python code into +// XlaLaunch nodes. +// +// xla.compile() does two main things: +// a) marks operators that make up an XLA computation with the attribute +// _xla_compile_id=XYZ, where XYZ is a unique key. +// b) adds XlaClusterOutput nodes to represent outputs of the computation. +// These nodes are not marked with the _xla_compile_id attribute. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" + + namespace tensorflow { + +// Encapsulates nodes marked with the _xla_compile_id attribute into +// XlaLaunch operators. +class EncapsulateXlaComputationsPass : public GraphOptimizationPass { + public: + static const char* const kXlaClusterAttr; // _xla_compile_id + + Status Run(const GraphOptimizationPassOptions& options) override; + + // The following methods are public only for unit tests. + + // This pass has two stages: + // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes + // marked with the same _xla_compile_id attribute into functions. These + // functions contain the computations to be passed to XlaLaunch. During + // encapsulation, we sort the arguments into the order expected by + // XlaLaunch. + static Status Encapsulate(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); + + // b) we rewrite the function calls generated in phase (a) into XlaLaunch + // operators. We also convert the XlaClusterOutput output nodes of the + // function call into the outputs of the XlaLaunch operator. + static Status BuildXlaLaunchOps(Graph* graph); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f643fb0cfe136caba42272d72f3972ec63a94bf3 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -0,0 +1,346 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h" +#include "tensorflow/compiler/tf2xla/test_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +static std::unique_ptr MakeOuterGraph( + const FunctionLibraryDefinition& flib_def, const string& function) { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NodeDef def; + TF_CHECK_OK( + NodeDefBuilder("launch0", function, &flib_def) + .Input(a.node()->name(), 0, DT_INT32) + .Input(b.node()->name(), 0, DT_FLOAT) + .Input(c.node()->name(), 0, DT_INT32) + .Input(d.node()->name(), 0, DT_FLOAT) + .Input(u.node()->name(), 0, DT_RESOURCE) + .Input(v.node()->name(), 0, DT_RESOURCE) + .Input(w.node()->name(), 0, DT_RESOURCE) + .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") + .Attr("_variable_start_index", 4) + .Finalize(&def)); + + Status status; + Node* launch = scope.graph()->AddNode(def, &status); + TF_CHECK_OK(status); + TF_CHECK_OK(scope.DoShapeInference(launch)); + scope.graph()->AddEdge(a.node(), 0, launch, 0); + scope.graph()->AddEdge(b.node(), 0, launch, 1); + scope.graph()->AddEdge(c.node(), 0, launch, 2); + scope.graph()->AddEdge(d.node(), 0, launch, 3); + scope.graph()->AddEdge(u.node(), 0, launch, 4); + scope.graph()->AddEdge(v.node(), 0, launch, 5); + scope.graph()->AddEdge(w.node(), 0, launch, 6); + + auto out0 = + ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0)); + auto out1 = + ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1)); + auto out2 = + ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2)); + auto out3 = + ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3)); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +// Makes an encapsulate body graph for use in tests. +static std::unique_ptr MakeBodyGraph() { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1); + auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3); + + auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4); + auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5); + auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); + + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), arg0, arg2); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, arg3); + add_attrs(g.node()); + + auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"), + b_identity, 0); + auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1); + auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2); + auto out3 = + ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { + // Test that control edge insertion order doesn't affect the cache key + // (cluster name) generated by TPU encapsulate pass. + auto get_serialized_graph = [](bool control_input_reversed, + bool operand_reversed) -> string { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32); + auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32); + + ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1) + : ops::Add(scope.WithOpName("E"), a1, a0); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, + "launch0"); + }; + add_attrs(e.node()); + + TF_CHECK_OK(scope.ToGraph(graph.get())); + auto get_node_in_graph = [&graph](Node* node) { + return graph->FindNodeId(node->id()); + }; + // Insert control edge in different order. The order should not affect + // the encapsulated or serialized graph. + if (!control_input_reversed) { + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + } else { + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + } + } + TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + GraphDef gdef; + graph->ToGraphDef(&gdef); + // Before serialization, sort control inputs first to remove + // nondeterminism. + SortControlInputs(&gdef); + string serialized; + SerializeToStringDeterministic(gdef, &serialized); + return serialized; + }; + + // Changing the order of control input shouldn't affect the graph generated. + EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true, + /*operand_reversed=*/false), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); + + // Changing the order of data input should affect the graph generated. + EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/true), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); +} + +TEST(EncapsulateXlaComputations, Encapsulate) { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b); + add_attrs(b_identity.node()); + + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), a, c); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, d); + add_attrs(g.node()); + + auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity); + auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e); + auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g); + auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + std::unique_ptr graph_copy(new Graph(&flib_def)); + CopyGraph(*graph, graph_copy.get()); + + TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + + std::unordered_map index = BuildNodeIndex(*graph); + string function = index.at("launch0")->type_string(); + + // Tests the outer graph is as expected. + { + std::unique_ptr outer = MakeOuterGraph(flib_def, function); + GraphDef expected_def; + outer->ToGraphDef(&expected_def); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def); + } + + // Tests the encapsulated body graph is as expected. + { + std::unique_ptr body = MakeBodyGraph(); + GraphDef expected_body_def; + body->ToGraphDef(&expected_body_def); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, + DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef); + } + + // Encapsulates the same computation again, verifies we reuse the same + // function. Encapsulation should be deterministic to avoid recompilation. + TF_ASSERT_OK( + EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def)); + std::unordered_map index_copy = BuildNodeIndex(*graph_copy); + string function_copy = index_copy.at("launch0")->type_string(); + EXPECT_EQ(function, function_copy); +} + +TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) { + std::unique_ptr body_graph = MakeBodyGraph(); + FunctionDefLibrary flib; + TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function())); + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + + std::unique_ptr graph = MakeOuterGraph(flib_def, "launch0"); + TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get())); + + Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NameAttrList function; + function.set_name("launch0"); + auto launch = ops::XlaLaunch( + scope.WithOpName("launch0"), std::initializer_list{}, + std::initializer_list{a, b, c, d}, + std::initializer_list{u, v, w}, + DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); + + auto consumer0_a = + ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]); + auto consumer0_b = + ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]); + auto consumer0_c = + ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]); + auto consumer1 = + ops::Identity(scope.WithOpName("consumer1"), launch.results[1]); + auto consumer2 = + ops::Identity(scope.WithOpName("consumer2"), launch.results[2]); + auto consumer3 = + ops::Identity(scope.WithOpName("consumer3"), launch.results[3]); + + GraphDef expected_def; + TF_ASSERT_OK(scope.ToGraphDef(&expected_def)); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ(expected_def, actual_def); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index c37b6112cc8a92047d495d057f59e2281710e678..3770eea6d09bb8ce7d83ddda253e5559ddc42e39 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -15,12 +15,31 @@ limitations under the License. #include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { +// PRE_PLACEMENT passes: + +// EncapsulateXlaComputationsPass rewrites computations generated by the +// xla.compile() Python code into XlaLaunch nodes. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, + EncapsulateXlaComputationsPass); + +// from +// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc +// FunctionalizeControlFlowPass: 27 +// +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) into functional +// control flow structure (XlaIf/XlaWhile). Following passes must +// handle those FunctionDef correctly. + +// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA: + REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 44caf0be5274fdd7bfd810988dad2522f57529b7..e6cc6e52ae537c23d18dc2d3fb94b40a5d23b1a5 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -443,7 +443,7 @@ Status FindCompilationCandidates( !registration->requires_compilation) { const OpDef* op_def; TF_RETURN_IF_ERROR( - OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def)); + graph.op_registry()->LookUpOpDef(node->type_string(), &op_def)); if (op_def->is_stateful()) { // We need to be able to constant fold the nodes in // compile_time_const_nodes given constant inputs (required by XLA) and diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 807ab51fd3c133b95915ea88e0bf99dbb8661452..c59770a4c8d4a5cb8508a928677f34aeb3d6acf5 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/memory/memory.h" #include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" @@ -633,7 +634,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { std::unique_ptr graph(new Graph(OpRegistry::Global())); Scope root = Scope::NewRootScope().ExitOnError(); { - auto BuildNoopNode = [](StringPiece name, Graph* graph) { + auto BuildNoopNode = [](absl::string_view name, Graph* graph) { NodeDefBuilder builder(name, "NoOp"); NodeDef def; TF_CHECK_OK(builder.Finalize(&def)); @@ -847,5 +848,51 @@ TEST(XlaCompilationTest, RandomShape) { EXPECT_EQ(clusters["shape"], ""); } +TEST(XlaCompilationTest, RandomShapeWithFunc) { + Scope root = Scope::DisabledShapeInferenceScope().ExitOnError(); + + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/"Stateful_func", /*in_def=*/{}, + /*out_def=*/{"out: int32"}, + /*attr_def*/ + {}, /*node_def=*/ + {FunctionDefHelper::Const("shape_shape", 2), + FunctionDefHelper::Const("minval", 1), + FunctionDefHelper::Const("maxval", 20), + {{"shape"}, + "RandomUniformInt", + {"shape_shape:output:0", "minval:output:0", "maxval:output:0"}, + {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}}, + /*ret_def=*/{{"out", "shape:output:0"}}); + + func.mutable_signature()->set_is_stateful(true); + *flib_def.add_function() = std::move(func); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + NodeDef call_node; + call_node.set_name("fn_call"); + call_node.set_op("Stateful_func"); + Status status; + Node* call = root.graph()->AddNode(call_node, &status); + TF_ASSERT_OK(status); + + Output shape = Output(call, 0); + Output reshape_input = + ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + auto fld = absl::make_unique(OpRegistry::Global(), + flib_def); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get())); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["fn_call"], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index f2473d98ffd5dae55983e601b8d2d65af6a6d54c..1a29c3caabe382b6c29244539575c5ba4e975f2f 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -13,10 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +using shape_inference::InferenceContext; + REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") @@ -32,4 +36,19 @@ REGISTER_OP("XlaLaunch") .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); +REGISTER_OP("XlaClusterOutput") + .Input("input: T") + // Note: when replication is supported, this op will have N outputs. + .Output("outputs: T") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(0)); + } + return Status::OK(); + }) + .Doc( + "Operator that connects the output of an XLA computation to other " + "consumer graph nodes."); + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 94c96ac7c568955ebd48eb6a4ec64210fafec331..ba218f3315d2607c47342fdade0403678faa2362 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -18,7 +18,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ -#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 6d4160a968111f6fdffcba1d6ad9698d8d0ea79e..af83c792e5e11d8596c521c6a3aed332a1f42e5b 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -339,11 +339,11 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { - manager_.CopyDeviceTensorToCPU(device_tensor, absl::string_view(tensor_name), - device, cpu_tensor, done); + manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, + done); } void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 1effd6628f8764e7dde794fb9a7da4a7aca2e895..df824212948ac96a5df5228cecd9a8c864bbec9a 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { @@ -111,12 +110,9 @@ class XlaDeviceContext : public DeviceContext { void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const override; - // TODO(rlahaye): Replace StringPiece with absl::string_view when the - // StringPiece->absl::string_view change is rolled forward. void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, // non-ABSL OK - Device* device, Tensor* cpu_tensor, - StatusCallback done) override; + absl::string_view tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 050d827a093bc498ba47810d9fa959459ca911fc..97ed554171f343991adcccf8c399756d06b13c5f 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -277,9 +277,10 @@ tf_xla_py_test( ], ) +# This test is large because occasionally the cpu test is long for testConcatLargeNumberOfTensors tf_xla_py_test( name = "concat_ops_test", - size = "medium", + size = "large", srcs = ["concat_ops_test.py"], deps = [ ":xla_test", @@ -581,6 +582,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", ], ) @@ -1197,7 +1199,7 @@ tf_xla_py_test( tf_xla_py_test( name = "xla_ops_test", - size = "small", + size = "medium", srcs = ["xla_ops_test.py"], disabled_backends = ["cpu_ondemand"], deps = [ diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index df0f21471a1c67e69e037f6409bcab1297d3399d..058576b3d4b695209952158769162bb24e7ccfce 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -56,7 +56,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -98,7 +98,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -140,7 +140,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 7b114d4f85d3a5cadc6af25b55c5a21f90d2a768..a76f136736f7c15788fb789dcb92bbd6becd8582 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -4,88 +4,97 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/tests:plugin.bzl", "plugins") def all_backends(): - b = ["cpu"] + plugins.keys() - if cuda_is_configured(): - return b + ["gpu"] - else: - return b + b = ["cpu"] + plugins.keys() + if cuda_is_configured(): + return b + ["gpu"] + else: + return b -def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, - disabled_backends=None, **kwargs): - """Generates py_test targets, one per XLA backend. +def tf_xla_py_test( + name, + srcs = [], + deps = [], + tags = [], + data = [], + main = None, + disabled_backends = None, + **kwargs): + """Generates py_test targets, one per XLA backend. - This rule generates py_test() targets named name_backend, for each backend - in all_backends(). The rule also generates a test suite with named `name` that - tests all backends for the test. + This rule generates py_test() targets named name_backend, for each backend + in all_backends(). The rule also generates a test suite with named `name` that + tests all backends for the test. - For example, the following rule generates test cases foo_test_cpu, - foo_test_gpu, and a test suite name foo_test that tests both. - tf_xla_py_test( - name="foo_test", - srcs="foo_test.py", - deps=[...], - ) + For example, the following rule generates test cases foo_test_cpu, + foo_test_gpu, and a test suite name foo_test that tests both. + tf_xla_py_test( + name="foo_test", + srcs="foo_test.py", + deps=[...], + ) - Args: - name: Name of the target. - srcs: Sources for the target. - deps: Dependencies of the target. - tags: Tags to apply to the generated targets. - data: Data dependencies of the target. - main: Same as py_test's main attribute. - disabled_backends: A list of backends that should not be tested. Supported - values include "cpu" and "gpu". If not specified, defaults to None. - **kwargs: keyword arguments passed onto the generated py_test() rules. - """ - if disabled_backends == None: - disabled_backends = [] + Args: + name: Name of the target. + srcs: Sources for the target. + deps: Dependencies of the target. + tags: Tags to apply to the generated targets. + data: Data dependencies of the target. + main: Same as py_test's main attribute. + disabled_backends: A list of backends that should not be tested. Supported + values include "cpu" and "gpu". If not specified, defaults to None. + **kwargs: keyword arguments passed onto the generated py_test() rules. + """ + if disabled_backends == None: + disabled_backends = [] - enabled_backends = [b for b in all_backends() if b not in disabled_backends] - test_names = [] - for backend in enabled_backends: - test_name = "{}_{}".format(name, backend) - backend_tags = ["tf_xla_{}".format(backend)] - backend_args = [] - backend_deps = [] - backend_data = [] - if backend == "cpu": - backend_args += [ - "--test_device=XLA_CPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" - ] - elif backend == "gpu": - backend_args += [ - "--test_device=XLA_GPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16" - ] - backend_tags += ["requires-gpu-sm35"] - elif backend in plugins: - backend_args += ["--test_device=" + plugins[backend]["device"], - "--types=" + plugins[backend]["types"]] - backend_tags += plugins[backend]["tags"] - backend_args += plugins[backend]["args"] - backend_deps += plugins[backend]["deps"] - backend_data += plugins[backend]["data"] - else: - fail("Unknown backend {}".format(backend)) + enabled_backends = [b for b in all_backends() if b not in disabled_backends] + test_names = [] + for backend in enabled_backends: + test_name = "{}_{}".format(name, backend) + backend_tags = ["tf_xla_{}".format(backend)] + backend_args = [] + backend_deps = [] + backend_data = [] + if backend == "cpu": + backend_args += [ + "--test_device=XLA_CPU", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64", + ] + elif backend == "gpu": + backend_args += [ + "--test_device=XLA_GPU", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16", + ] + backend_tags += ["requires-gpu-sm35"] + elif backend in plugins: + backend_args += [ + "--test_device=" + plugins[backend]["device"], + "--types=" + plugins[backend]["types"], + ] + backend_tags += plugins[backend]["tags"] + backend_args += plugins[backend]["args"] + backend_deps += plugins[backend]["deps"] + backend_data += plugins[backend]["data"] + else: + fail("Unknown backend {}".format(backend)) - native.py_test( - name=test_name, - srcs=srcs, - srcs_version="PY2AND3", - args=backend_args, - main="{}.py".format(name) if main == None else main, - data=data + backend_data, - deps=deps + backend_deps, - tags=tags + backend_tags, - **kwargs - ) - test_names.append(test_name) - native.test_suite(name=name, tests=test_names) + native.py_test( + name = test_name, + srcs = srcs, + srcs_version = "PY2AND3", + args = backend_args, + main = "{}.py".format(name) if main == None else main, + data = data + backend_data, + deps = deps + backend_deps, + tags = tags + backend_tags, + **kwargs + ) + test_names.append(test_name) + native.test_suite(name = name, tests = test_names) -def generate_backend_suites(backends=[]): - """Generates per-backend test_suites that run all tests for a backend.""" - if not backends: - backends = all_backends() - for backend in backends: - native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend]) +def generate_backend_suites(backends = []): + """Generates per-backend test_suites that run all tests for a backend.""" + if not backends: + backends = all_backends() + for backend in backends: + native.test_suite(name = "%s_tests" % backend, tags = ["tf_xla_%s" % backend]) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 37e5318bb54c5d8ecdedc7bb346e89765f2adf35..2d225ad226cac368042b95eae8fc29e6fd8e82e0 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -291,6 +291,41 @@ class ConcatTest(xla_test.XLATestCase): ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"): array_ops.concat([scalar, scalar, scalar], dim) + # The purpose of this is to ensure that XLA on GPU will not run out of memory + # with too many arguments. + def testConcatLargeNumberOfTensors(self): + with self.cached_session(): + with self.test_scope(): + for concat_dim in range(2): + params = {} + p = [] + shape = np.array([7, 13]) + num_tensors = 1001 + for i in np.arange(num_tensors): + input_shape = shape + placeholder = array_ops.placeholder( + dtypes.float32, shape=input_shape) + p.append(placeholder) + params[placeholder] = np.random.rand(*input_shape).astype( + np.float32) + + concat_inputs = p + c = array_ops.concat(concat_inputs, concat_dim) + result = c.eval(feed_dict=params) + + self.assertEqual(result.shape, c.get_shape()) + cur_offset = 0 + + for i in np.arange(num_tensors): + # The index into the result is the ':' along all dimensions + # except the concat_dim. slice(0, size) is used for ':', and + # a list of slices is used to index into result. + index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)] + index[concat_dim] = slice( + cur_offset, cur_offset + params[p[i]].shape[concat_dim]) + cur_offset += params[p[i]].shape[concat_dim] + self.assertAllEqual(result[index], params[p[i]]) + class ConcatOffsetTest(xla_test.XLATestCase): diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 9222db4b7ebf020c8cee1c0af81e05129fb33c4d..c61965b97fc142ce452cf28def8c937f692d2f84 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test @@ -26,38 +27,167 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MatrixBandPartTest(xla_test.XLATestCase): +class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase): - def _testMatrixBandPart(self, dtype, shape): - with self.cached_session(): - batch_shape = shape[:-2] - mat = np.ones(shape).astype(dtype) - batch_mat = np.tile(mat, batch_shape + [1, 1]) - for lower in -1, 0, 1, shape[-2] - 1: - for upper in -1, 0, 1, shape[-1] - 1: - band_np = mat - if lower >= 0: - band_np = np.triu(band_np, -lower) - if upper >= 0: - band_np = np.tril(band_np, upper) - if batch_shape: - band_np = np.tile(band_np, batch_shape + [1, 1]) - - placeholder = array_ops.placeholder(dtype) - with self.test_scope(): - band = array_ops.matrix_band_part( - placeholder, - constant_op.constant(lower, dtype=dtypes.int32), - constant_op.constant(upper, dtype=dtypes.int32)) - feed_dict = {placeholder: batch_mat} - self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) - - def testMatrixBandPart(self): + @parameterized.parameters( + { + 'batch_shape': [], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 7 + }, + ) + def testMatrixBandPart(self, batch_shape, rows, cols): for dtype in self.float_types: - for batch_shape in [[], [2,], [1, 3, 2]]: - for rows in 1, 2, 7: - for cols in 1, 2, 7: - self._testMatrixBandPart(dtype, batch_shape + [rows, cols]) + with self.cached_session(): + mat = np.ones(batch_shape + [rows, cols]).astype(dtype) + batch_mat = np.tile(mat, batch_shape + [1, 1]) + for lower in -1, 0, 1, rows - 1: + for upper in -1, 0, 1, cols - 1: + band_np = mat + if lower >= 0: + band_np = np.triu(band_np, -lower) + if upper >= 0: + band_np = np.tril(band_np, upper) + if batch_shape: + band_np = np.tile(band_np, batch_shape + [1, 1]) + + placeholder = array_ops.placeholder(dtype) + with self.test_scope(): + band = array_ops.matrix_band_part( + placeholder, constant_op.constant(lower, dtype=dtypes.int32), + constant_op.constant(upper, dtype=dtypes.int32)) + feed_dict = {placeholder: batch_mat} + self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py index 84c67779400f7a800bd88abc32d95058a6c0904d..96e0b074754032dd64c479b5e587b664ff066e2b 100644 --- a/tensorflow/compiler/tests/reshape_op_test.py +++ b/tensorflow/compiler/tests/reshape_op_test.py @@ -33,7 +33,7 @@ class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase): ('64_bit_index', dtypes.int64)) def testBasic(self, index_dtype): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): shape = constant_op.constant([3, 2], dtype=index_dtype) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 3f928a1beabf60f3e3e5d86af7eea8bb36c375c8..1e600c44e9af66994686359eb0e1a1e52bea93fd 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -296,6 +297,44 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): self._assertOpOutputMatchesExpected( lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + def testDynamicSlice(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.dynamic_slice, + args=(np.arange(1000, + dtype=np.int32).astype(dtype).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3, 2])), + expected=np.array( + np.array([[[573, 574], [583, 584], [593, 594]], + [[673, 674], [683, 684], [693, 694]]]), + dtype=dtype)) + + def testDynamicSliceWithIncorrectStartIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7]), np.array([2, 3, 4])) + with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^start_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and start_indices has shape \[2\].*')) + + def testDynamicSliceWithIncorrectSizeIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3])) + with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^size_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and size_indices has shape \[2\].*')) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 3821dced638458837ecd61191e93d464dc2c1f99..ba1e3b2b4fdbb73e98105ace6571783ef780adf5 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -76,6 +76,7 @@ cc_library( deps = [ ":common", ":dump_graph", + ":functionalize_control_flow", ":tf2xla_proto", ":tf2xla_util", ":xla_compiler", @@ -188,7 +189,6 @@ cc_library( deps = [ ":common", ":dump_graph", - ":functionalize_control_flow", ":host_compute_metadata_proto", ":sharding_util", ":side_effect_util", @@ -215,7 +215,6 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], alwayslink = 1, @@ -285,6 +284,7 @@ cc_library( deps = [ ":sharding_util", ":tf2xla_proto", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -480,6 +480,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -507,11 +508,23 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], ) +cc_library( + name = "functionalize_control_flow_pass_registration", + srcs = [ + "functionalize_control_flow_pass_registration.cc", + ], + deps = [ + ":functionalize_control_flow", + ], + alwayslink = 1, +) + cc_library( name = "functionalize_while", srcs = [ @@ -521,6 +534,7 @@ cc_library( "functionalize_while.h", ], deps = [ + ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", @@ -531,6 +545,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], @@ -545,6 +560,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/cc:xla_ops", @@ -595,6 +611,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 0911550f1fe6e3106e9772288c688023bb80bbe3..db256e577a1f3dd38e04d102f60182023b9d43b2 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" using xla::StatusOr; @@ -217,10 +218,6 @@ void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) { added_node_ancestorid_mapping_[node->id()] = id; } -const StateMap::CondState& StateMap::LookupState(const Node* node) const { - return *LookupCondId(node); -} - void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); } string StateMap::CondStateToString(const Node* node) const { @@ -642,7 +639,7 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If"); + NodeDefBuilder builder(name(), "If", library); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); @@ -791,7 +788,6 @@ Status Conditional::BuildAndReplace(Graph* graph, TF_RETURN_IF_ERROR(AddInputEdges(graph)); TF_RETURN_IF_ERROR(AddOutputEdges(graph)); TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); - for (Node* m : merges_) state_map_->MarkDead(m); // Check that the if_node doesn't feed into itself. TF_RETURN_WITH_CONTEXT_IF_ERROR( @@ -1056,7 +1052,6 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { " has no non-dead inputs."); } state_map_.MarkDead(node); - delete_nodes_.push_back(node->id()); VLOG(5) << "removing redundant merge: " << node->name(); while (!node->out_edges().empty()) { const Edge* oe = *node->out_edges().begin(); @@ -1132,7 +1127,6 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { } } else if (BranchType(switch_branch) != b) { state_map_.MarkDead(dst_node); - delete_nodes_.push_back(dst_node->id()); continue; } graph_->AddEdge( @@ -1154,7 +1148,7 @@ Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst) << " @ " << state_map_.AncestorStateToString(dst); - if (VLOG_IS_ON(10)) DumpGraphWithCondState("cond_it"); + if (VLOG_IS_ON(10)) DumpGraphWithCondState("it"); } return Status::OK(); } @@ -1184,23 +1178,62 @@ Status FunctionalizeCond::DetermineAncestorState(Node* dst) { return Status::OK(); } -void FunctionalizeCond::DeleteReachableNodes() { +void FunctionalizeCond::DeleteReachableAndDeadNodes( + const std::vector& switch_ids, const std::vector& merge_order) { // Delete all nodes that have been extracted or are reachable from // deleted/dead nodes. The input and outgoing edges should have already been // removed. + std::deque delete_nodes; std::vector deleted(graph_->num_node_ids(), false); // Don't try to delete source or sink nodes. deleted[graph_->kSourceId] = true; deleted[graph_->kSinkId] = true; - while (!delete_nodes_.empty()) { - int d_id = delete_nodes_.front(); - delete_nodes_.pop_front(); + + // All remaining Switch nodes are not reachable from a Merge node and + // removed. This is to account for dead Switch nodes. + for (int s_id : switch_ids) { + Node* s = graph_->FindNodeId(s_id); + if (s == nullptr) continue; + for (const Edge* e : s->out_edges()) { + // Control outputs of switch nodes (which are unconditionally executed if + // the switch is) are not removed as they need not be part of a + // conditional. + if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); + } + deleted[s_id] = true; + graph_->RemoveNode(s); + } + + // All merge nodes should have been transformed at this point and we remove + // them from the graph here. + for (Node* m : merge_order) { + for (const Edge* e : m->out_edges()) { + // Similar to control outputs of switch nodes don't remove control + // outputs of merge nodes. + // TODO(jpienaar): Check cases where output edges still exist here vs + // being removed in AddOutputEdges. + if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); + } + deleted[m->id()] = true; + graph_->RemoveNode(m); + } + + // Enqueue all the dead nodes. + for (Node* n : graph_->nodes()) { + if (state_map_.IsDead(state_map_.LookupCondId(n))) { + delete_nodes.push_back(n->id()); + } + } + + while (!delete_nodes.empty()) { + int d_id = delete_nodes.front(); + delete_nodes.pop_front(); if (deleted[d_id]) continue; Node* d = graph_->FindNodeId(d_id); // Switch and Merge nodes could have been deleted already. if (d == nullptr) continue; for (const Edge* e : d->out_edges()) { - delete_nodes_.push_back(e->dst()->id()); + delete_nodes.push_back(e->dst()->id()); } deleted[d_id] = true; graph_->RemoveNode(d); @@ -1274,7 +1307,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { } TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order))); - if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id"); + if (VLOG_IS_ON(4)) DumpGraphWithCondState("id"); // Sort the merge nodes from innermost outwards. SortMergeNodes(&merge_order); @@ -1312,11 +1345,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); } - // All remaining Switch nodes are not reachable from a Merge node and - // removed. This is to account for dead Switch nodes. - for (int s_id : switch_ids) delete_nodes_.push_back(s_id); - for (Node* m : merge_order) delete_nodes_.push_back(m->id()); - DeleteReachableNodes(); + DeleteReachableAndDeadNodes(switch_ids, merge_order); return Status::OK(); } @@ -1331,8 +1360,9 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) { state_map_.AncestorStateToString(n))); } LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " - << dump_graph::DumpGraphToFile(absl::StrCat("functionalize_", name), - *graph_, library_); + << dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_cond_", name), *graph_, + library_); } Status FunctionalizeCond::Functionalize(Graph* graph, diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 28301150ea506e09c0b1addcd8ca77edee905275..189980894073b1da1a12d1c284536336eb920900 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -91,10 +91,6 @@ class StateMap { // Resets the AncestorId for a given node. void ResetAncestorId(const Node* node, AncestorId id); - // Returns the CondState for a Node. - // REQUIRES: node has a non-empty CondState. - const CondState& LookupState(const Node* node) const; - // Marks `node` as dead. void MarkDead(const Node* node); @@ -221,8 +217,10 @@ class FunctionalizeCond { // nesting depth. void SortMergeNodes(std::vector* merge_order); - // Deletes all nodes in/consumers of `delete_nodes_`. - void DeleteReachableNodes(); + // Deletes all nodes in/consumers reachable from switch/merge nodes that were + // extracted. + void DeleteReachableAndDeadNodes(const std::vector& switch_ids, + const std::vector& merge_order); // Member used to unique the CondState to a unique CondId (AncestorState to a // unique AncestorId) and keep track of CondState/CondId @@ -232,9 +230,6 @@ class FunctionalizeCond { // Mapping from merge nodes to predicate. std::unordered_map merge_to_predicate_; - // Nodes to be deleted. - std::deque delete_nodes_; - FunctionLibraryDefinition* library_; Graph* graph_; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 5932be4e525dec11a8f3c59bb85e0449e76e79c0..f792c520329039c8da63d07ea27fa1c403f5c67d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -31,11 +31,16 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -68,4 +73,146 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } +Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + std::map* canonicalized_name_to_new_name) { + // Convert the function to Graph. + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* body = flr->GetFunctionBody(handle); + const FunctionDef& fdef = body->fdef; + + // If any node has associated functions, functionalize them first. + // Gather nodes with associated functions first, because rewriting those nodes + // might involve node deletion/addition. Avoid modifying nodes while iterating + // it. + std::vector>> + nodes_to_associated_functions; + for (auto* n : body->graph->nodes()) { + auto associated_functions = GetAssociatedFunctions(*n, flr); + if (!associated_functions.empty()) { + nodes_to_associated_functions.push_back({n, associated_functions}); + } + } + for (auto iter : nodes_to_associated_functions) { + Node* n = iter.first; + auto associated_functions = iter.second; + for (auto& associated_function : associated_functions) { + string name = associated_function.func_name(); + string canonicalized_name = Canonicalize(name, AttrSlice(&attrs)); + auto iter = canonicalized_name_to_new_name->find(canonicalized_name); + string new_name; + if (iter != canonicalized_name_to_new_name->end()) { + // If we already functionalized this function, skip functionalization + // but still rewrite the node. + new_name = iter->second; + } else { + new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + name, new_name, attrs, fld, flr, canonicalized_name_to_new_name)); + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + } + // Notice that if "n" is a function call, RewriteAssociatedFunction() will + // delete it and create a new node instead, making "n" an invalid pointer. + // That's fine because in that case, associated_functions will only have + // one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + body->graph, n, fld, associated_function, new_name)); + } + } + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *body->graph, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), + *body->graph, fld); + } + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef)); + + // Copy signature and ret from original FunctionDef. + *functionalized_fdef.mutable_signature() = fdef.signature(); + *functionalized_fdef.mutable_ret() = fdef.ret(); + functionalized_fdef.mutable_signature()->set_name(new_func_name); + + // Add rewritten FunctionDef into library. + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } + + return ret_status; +} + +Status FunctionalizeControlFlowPass::Run( + const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph, + options.flib_def); + } + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime( + /*device_mgr=*/nullptr, options.session_options->env, + TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + // Find XLA compile ops and its corresponding FunctionDef. + static std::map* kNodeTypeToFunctionAttrMapping = + new std::map{ + {"TPUCompile", "function"}, + {"XlaLaunch", "function"}, + }; + std::map canonicalized_name_to_new_name; + for (Node* n : graph->nodes()) { + auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); + if (it == kNodeTypeToFunctionAttrMapping->end()) { + continue; + } + const string func_attr = it->second; + if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != + kNodeTypeToFunctionAttrMapping->end()) { + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); + VLOG(2) << "Graph has node " << n->type_string() + << ". Corresponding function: " << func.name(); + string new_func_name = options.flib_def->UniqueFunctionName( + absl::StrCat(func.name(), "_f15n_")); + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func.name(), new_func_name, func.attr(), options.flib_def, flr, + &canonicalized_name_to_new_name)); + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); + } + } + + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph, + options.flib_def); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 55600f2a8b5302cef26b9be4ccd0f8804476a17a..ba99205640ccdc83a3a4d50e3ec474907894a835 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -32,6 +33,14 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library); +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) into functional +// control flow structure (If/While). +class FunctionalizeControlFlowPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc new file mode 100644 index 0000000000000000000000000000000000000000..a10a9d0499457bbc0383ea3a8c678f153e21894b --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +namespace tensorflow { + +// This pass is required for some AOT backends and all JIT backends, so this +// file exists as a separate lib and will be linked to both AOT and JIT. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27, + FunctionalizeControlFlowPass); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index c068a4110c0bb14282379eb7a3cbdae4e80ddbd6..c3841f996f801e855da75b23f01d41674ec51c4d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" @@ -112,16 +113,12 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, - std::initializer_list{less, y, x}, then_fn, - else_fn, {DT_INT32}); + auto if_op = ops::If(scope.WithOpName(op_name), less, + std::initializer_list{less, y, x}, {DT_INT32}, + then_fn, else_fn); auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); - // TODO(jpienaar): Create wrapper for IfOp. - for (NodeDef& n : *expected.mutable_node()) { - if (n.op() == "XlaIf") n.set_op("If"); - } TF_EXPECT_GRAPH_EQ(expected, graph_def); } @@ -177,7 +174,7 @@ TEST(FunctionalizeControlFlow, Conditional) { Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, NameAttrList* body) { for (const NodeDef& node : graph.node()) { - if (node.op() == "XlaWhile") { + if (node.op() == "While") { const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result)); *cond = *result; @@ -186,7 +183,7 @@ Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, return Status::OK(); } } - return errors::NotFound("No XlaWhile node found in graph"); + return errors::NotFound("No While node found in graph"); } // Graph: @@ -255,8 +252,8 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); @@ -392,8 +389,8 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); GraphDef expected; TF_ASSERT_OK(scope.ToGraphDef(&expected)); TF_EXPECT_GRAPH_EQ(expected, graph_def); @@ -483,8 +480,8 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); TF_EXPECT_GRAPH_EQ(expected, graph_def); @@ -625,8 +622,8 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{x, y}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{x, y}, cond_fn, body_fn); auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); GraphDef expected; @@ -864,9 +861,9 @@ TEST(FunctionalizeControlFlow, Complex) { auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); - auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"), - std::initializer_list{zero, y, x, var}, - outer_cond_fn, outer_body_fn); + auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); @@ -921,9 +918,9 @@ TEST(FunctionalizeControlFlow, Complex) { auto one_j = ops::Const( scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); auto while_op = - ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"), - std::initializer_list{one_j, arg1, arg2, arg3}, - inner_cond_fn, inner_body_fn); + ops::While(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); auto one_outer = ops::Const( scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 7f45e3bffa0eec8fa0d879c0d8011545221acb3d..7c3ad448ef546dd1ab2640a57d7d1d73ca3768ad 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace { @@ -473,12 +475,19 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, } } - // Builds the condition and body functions. + // Builds the condition and body functions. Notice that we call + // FunctionalizeCond() on cond_graph and body_graph because we might have + // unfunctionalized "if" in cond_graph and body_graph. Functionalize them + // before they are encapsulated in FunctionDef. std::unique_ptr cond_graph; TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); + FixupSourceAndSinkEdges(cond_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library)); DataTypeVector arg_types; std::unique_ptr body_graph; TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); + FixupSourceAndSinkEdges(body_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library)); VLOG(2) << "Frame " << frame->name << " condition: " << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) @@ -510,7 +519,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, // Builds a While operator. NodeDef while_def; - NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); + NodeDefBuilder builder(frame->loop_cond->name(), "While", library); builder.Attr("T", arg_types); builder.Attr("cond", cond_name); builder.Attr("body", body_name); @@ -653,9 +662,9 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, // There should be no cycle at this point, since while loops have been removed // from graph. - // Check that the newly added XlaWhile nodes don't feed into themselves. + // Check that the newly added While nodes don't feed into themselves. for (const Node* node : graph->op_nodes()) { - if (node->def().op() == "XlaWhile") { + if (node->def().op() == "While") { TF_RETURN_WITH_CONTEXT_IF_ERROR( CheckNodeNotInCycle(node, graph->num_node_ids()), "Functionalizing loop failed."); diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index bc2e6405596a7c61e64d0f4e3152e80ff562f2a0..c019a28e892ff89f559ddbec2360d6caa9c1808f 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -81,7 +80,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, TF_ASSIGN_OR_RETURN(auto literal, client->ComputeConstant(constant_graph)); TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); + LiteralToHostTensor(literal, arg.type, &arg.constant_value)); } else { arg.kind = XlaCompiler::Argument::kParameter; } diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index ab7cac7100d39377828462f0dee5df98a7319cc3..e9f02201cf6bed5495dff7dff76c5bafe7771516 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -55,17 +55,17 @@ namespace tensorflow { // op registration infrastructure instead of FunctionLibraryRuntime. class GraphCompiler { public: - GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device, - Graph* graph, FunctionLibraryRuntime* flib, + GraphCompiler(XlaCompilationDevice* device, Graph* graph, + FunctionLibraryRuntime* flib, ScopedStepContainer* step_container) - : xla_context_(xla_context), - device_(device), + : device_(device), graph_(graph), flib_(flib), step_container_(step_container) {} - // Compiles the graph. The results are written in `xla_context` that is passed - // into the compiler. + // Compiles the graph. The results are written in xla_context stored in the + // resource_manager of the 'XlaCompilationDevice' that's passed into the + // constructor. Status Compile(); private: @@ -82,7 +82,6 @@ class GraphCompiler { // using `compiler_`. Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); - XlaContext* xla_context_; XlaCompilationDevice* device_; Graph* graph_; FunctionLibraryRuntime* flib_; diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index df17da4c1ca07053cf63757f1acf2b1a3735e705..0d9a768a6f47a823020498315d4c40b5854fdbe7 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -66,6 +66,9 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + if (DataTypeIsUnsigned(dtype)) { + return xla::Div(x, y); + } auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)); diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index f4106051043859a6786705009d76b02a64cd3ff1..0ae23aa6dfe49048ac5cb8ae00c12432b2e2a2fe 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -37,6 +37,16 @@ limitations under the License. namespace tensorflow { namespace { +// Used to determine the number of Tensors allowed in a Concat op to prevent +// going over the max gpu parameter memory size. This is an issue because concat +// is variadic and can have an unlimited number of arguments when called. +// Concat ops with more Tensors than this will be split into multiple concat +// ops. +// +// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass +// along with boxing large numbers of parameters. +constexpr int64 kMaxConcatArgsPerOp = 500; + // -------------------------------------------------------------------------- class ConcatBaseOp : public XlaOpKernel { public: @@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel { // Make a vector holding the XlaOp for each of the inputs that has non-zero // elements. std::vector input_data; + std::vector partial_concats; int output_concat_dim = 0; const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { @@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel { input_data.push_back(handle); } output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1; + + // Concat is associative, so it can be split into many operations when too + // many arguments are in a single op. This is a temporary workaround for + // b/112613927 where too many parameters in an XlaLaunchOp later result in + // too many parameters to a single GPU kernel. + if (i && i % kMaxConcatArgsPerOp == 0) { + partial_concats.push_back( + xla::ConcatInDim(ctx->builder(), input_data, axis)); + input_data.clear(); + } } + // Add any inputs that have not been put into another concat yet. + partial_concats.insert(partial_concats.end(), input_data.begin(), + input_data.end()); VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); + // Don't add an additional "identity" concatenate for better readibility of + // IR. + if (partial_concats.size() == 1) { + ctx->SetOutput(0, partial_concats.front()); + } else { + ctx->SetOutput(0, + xla::ConcatInDim(ctx->builder(), partial_concats, axis)); + } } private: diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index a3389d5b905bf3ee15744ab4fcee193d312e2ae0..4af1e8b44cbbd02d8e3ea5e42d841c92288b5d56 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -34,15 +34,12 @@ class DynamicUpdateSliceOp : public XlaOpKernel { : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* ctx) override { - VLOG(3) << "DynamicUpdateSliceOp::Compile"; + DataType index_type = ctx->InputType("indices"); + CHECK(index_type == DT_INT32 || index_type == DT_INT64); - DataType index_type = input_type(2); - OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("index must be int32 or int64")); - - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape update_shape = ctx->InputShape(1); - const TensorShape index_shape = ctx->InputShape(2); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape update_shape = ctx->InputShape("update"); + const TensorShape index_shape = ctx->InputShape("indices"); OP_REQUIRES( ctx, @@ -57,13 +54,56 @@ class DynamicUpdateSliceOp : public XlaOpKernel { input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); - xla::XlaOp result = - xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2)); + xla::XlaOp result = xla::DynamicUpdateSlice( + ctx->Input("input"), ctx->Input("update"), ctx->Input("indices")); ctx->SetOutput(0, result); } }; REGISTER_XLA_OP(Name("XlaDynamicUpdateSlice"), DynamicUpdateSliceOp); +class DynamicSliceOp : public XlaOpKernel { + public: + explicit DynamicSliceOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + DataType index_type = ctx->InputType("start_indices"); + CHECK(index_type == DT_INT32 || index_type == DT_INT64); + CHECK(index_type == ctx->InputType("size_indices")); + + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape start_indices_shape = ctx->InputShape("start_indices"); + const TensorShape size_indices_shape = ctx->InputShape("size_indices"); + + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(start_indices_shape) && + start_indices_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "start_indices must be a vector with length equal to " + "input rank, but input rank is ", + input_shape.dims(), " and start_indices has shape ", + start_indices_shape.DebugString())); + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(size_indices_shape) && + size_indices_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "size_indices must be a vector with length equal to " + "input rank, but input rank is ", + input_shape.dims(), " and size_indices has shape ", + size_indices_shape.DebugString())); + + std::vector size_indices; + OP_REQUIRES_OK( + ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices)); + xla::XlaOp result = xla::DynamicSlice( + ctx->Input("input"), ctx->Input("start_indices"), size_indices); + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("XlaDynamicSlice").CompileTimeConstInput("size_indices"), + DynamicSliceOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 22a45b2a11e8ecb688f8e773ef4b286eafe68f4f..3d81ae9eb89a80e5b89b180ad77521c5ed15e79d 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel { std::vector args; args.push_back(ctx->Input(0)); args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0(dim))); + xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(dim))); } xla::Shape xla_shape = diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index c26784852472061ffead03cfe7431f8b8ba0e555..804671fbc75b0a5a6e04b204822b6f084013cd8b 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, xla::Literal literal; switch (type) { case xla::U8: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::U32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::U64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S8: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::F32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::F64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::C64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::BF16: - literal = std::move( - *xla::LiteralUtil::CreateR0(static_cast(value))); + literal = + xla::LiteralUtil::CreateR0(static_cast(value)); break; case xla::F16: - literal = std::move(*xla::LiteralUtil::CreateR0( - static_cast(value))); + literal = + xla::LiteralUtil::CreateR0(static_cast(value)); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 7dc16b5a46791b81eef2c572736e1a1c7969b203..15f4c38da29507da9e092c1d5725b5f95a81d1b9 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -22,51 +22,61 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { +namespace { TEST(LiteralUtil, LiteralToHostTensor) { // int64 literal can only be converted to an int64 host tensor. - { - std::vector int64_values = {1, 2, 3}; - std::unique_ptr int64_values_literal = - xla::LiteralUtil::CreateR1(absl::Span(int64_values)); - Tensor host_tensor; - EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", - LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) - .error_message()); - EXPECT_EQ( - "Cannot convert literal of type S64 to tensor of type qint32", - LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor) - .error_message()); - EXPECT_TRUE( - LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor) - .ok()); - test::ExpectTensorEqual(host_tensor, - test::AsTensor(int64_values)); - } + std::vector int64_values = {1, 2, 3}; + xla::Literal int64_values_literal = + xla::LiteralUtil::CreateR1(absl::Span(int64_values)); + Tensor host_tensor; + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", + LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor) + .error_message()); + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32", + LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor) + .error_message()); + EXPECT_TRUE( + LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int64_values)); +} + +template +using LiteralUtilTest = ::testing::Test; +using Types = + ::testing::Types, std::pair, + std::pair, std::pair, + std::pair>; + +TYPED_TEST_CASE(LiteralUtilTest, Types); + +TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) { + using int_type = typename TypeParam::first_type; + using qint_type = typename TypeParam::second_type; - { - // Repeat tests with int32. - Tensor host_tensor; - std::vector int32_values = {10, 11}; - std::unique_ptr int32_values_literal = - xla::LiteralUtil::CreateR1(absl::Span(int32_values)); - EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) - .ok()); - test::ExpectTensorEqual(host_tensor, - test::AsTensor(int32_values)); + Tensor host_tensor; + std::vector int_values = {10, 11}; + xla::Literal int_values_literal = + xla::LiteralUtil::CreateR1(absl::Span(int_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum::value, &host_tensor) + .ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int_values)); - EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor) - .ok()); - std::vector qint32_values = {10, 11}; - test::ExpectTensorEqual(host_tensor, - test::AsTensor(qint32_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum::value, + &host_tensor) + .ok()); + std::vector qint_values = {10, 11}; + test::ExpectTensorEqual(host_tensor, + test::AsTensor(qint_values)); - EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64", - LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor) - .error_message()); - } + EXPECT_EQ( + error::INVALID_ARGUMENT, + LiteralToHostTensor(int_values_literal, DT_INT64, &host_tensor).code()); } +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 68cfdc178563ceeee1fb18cd0c890f115c1a8587..02363500efe1a11348eaf7d8b99da76307acdd3c 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -105,6 +105,35 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto. precision_config: a serialized xla::PrecisionConfig proto. )doc"); +REGISTER_OP("XlaDynamicSlice") + .Input("input: T") + .Input("start_indices: Tindices") + .Input("size_indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA DynamicSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice +. + +DynamicSlice extracts a sub-array from the input array at dynamic +start_indices. The size of the slice in each dimension is passed in +size_indices, which specify the end point of exclusive slice intervals in each +dimension -- [start, start + size). The shape of start_indices must be rank == +1, with dimension size equal to the rank of operand. + +input: A `Tensor` of type T. + +start_indices: Rank 1 tensor of N integers containing the starting indices of + the slice for each dimension. Value must be greater than or equal to zero. + +start_indices: List of N integers containing the slice size for each + dimension. Each value must be strictly greater than zero, and start + size + must be less +)doc"); + REGISTER_OP("XlaDynamicUpdateSlice") .Input("input: T") .Input("update: T") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 3626de375ea9ac12e40ea5b5b591bb6d5262adbc..27dd18a9bbd5aceece41aaf61eb185acb537b3b6 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -291,13 +291,7 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): name=name) -def dynamic_slice(x, starts, sizes, name=None): - # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not - # a compile-time constant. This doesn't exactly mimic the semantics of dynamic - # slice if the slice is out of bounds. - return array_ops.slice(x, starts, sizes, name=name) - - +dynamic_slice = gen_xla_ops.xla_dynamic_slice dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice # TODO(phawkins): generalize tf.pad to support interior padding, and then remove diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 92577b5bc82b6b0927db51b6d96f37d8c886f1c3..20f2ce2919701731ef6e90d368b67545af95e8f9 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { @@ -31,10 +30,11 @@ namespace tensorflow { } } -static gtl::FlatMap* CreateResourceOpInfoMap() { - auto* result = new gtl::FlatMap; +static gtl::FlatMap* +CreateResourceOpInfoMap() { + auto* result = new gtl::FlatMap; - auto add = [&](StringPiece op, XlaResourceOpKind op_kind, + auto add = [&](absl::string_view op, XlaResourceOpKind op_kind, XlaResourceKind resource_kind) { auto insert_result = result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); @@ -103,17 +103,17 @@ static gtl::FlatMap* CreateResourceOpInfoMap() { return result; } -static const gtl::FlatMap& +static const gtl::FlatMap& GetStaticResourceOpInfoMap() { - static gtl::FlatMap* op_info_map = + static gtl::FlatMap* op_info_map = CreateResourceOpInfoMap(); return *op_info_map; } const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) { - const gtl::FlatMap& op_infos = + const gtl::FlatMap& op_infos = GetStaticResourceOpInfoMap(); - auto it = op_infos.find(StringPiece(op.data(), op.length())); + auto it = op_infos.find(op); return it == op_infos.end() ? nullptr : &it->second; } @@ -121,7 +121,7 @@ namespace resource_op_table_internal { std::vector GetKnownResourceOps() { std::vector result; for (const auto& p : GetStaticResourceOpInfoMap()) { - result.push_back(absl::string_view(p.first)); + result.push_back(p.first); } absl::c_sort(result); return result; diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 3c6c9a91b6d2fb47f6dee1c347e9b852f1eea3ec..f31bfb45a2f4db270446eb59259969dc0ab63a8e 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -40,4 +40,12 @@ Status InstantiateFunctionForTest(const string& name, return Status::OK(); } +std::unordered_map BuildNodeIndex(const Graph& graph) { + std::unordered_map index; + for (Node* node : graph.nodes()) { + index[node->name()] = node; + } + return index; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index e6e4ae92ed23f3fca0f59b131dc73152e0947b72..350a868568531c0d073e0cf600327d1ff9d62e3a 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -24,8 +24,10 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -42,6 +44,20 @@ Status InstantiateFunctionForTest(const string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result); +// Builds a map from node name to Node* for `graph`. +std::unordered_map BuildNodeIndex(const Graph& graph); + } // namespace tensorflow +// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for +// equality. +#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \ + do { \ + string diff; \ + EqualGraphDefOptions eq_options; \ + eq_options.ignore_internal_attrs = false; \ + EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \ + << diff << "\nActual: " << SummarizeGraphDef(actual); \ + } while (false) + #endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 7dbe3a0b5816c71a2174c02b0da32f4da0e44991..b22d53805d83069052cc5e16020d6c540d618a82 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -340,6 +341,13 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), second_copy_def, g.get())); TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping)); + + // Functionalize control flow. + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def)); + // After control flow functionalization, we might have more FunctionDef's + // (then/else branch, loop body). Add them to the graph. + TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto())); + *graph = std::move(g); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 56f7045a98201ed398244f9e3f5ff23788135b75..ab26d939ccba75ce58609ffd71c7ccadbe90cfa8 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) { // Set up arguments. auto x_literal = xla::LiteralUtil::CreateR0(10); auto y_literal = xla::LiteralUtil::CreateR0(32); - auto x_global_or = client->TransferToServer(*x_literal); - auto y_global_or = client->TransferToServer(*y_literal); + auto x_global_or = client->TransferToServer(x_literal); + auto y_global_or = client->TransferToServer(y_literal); TF_EXPECT_OK(x_global_or.status()); TF_EXPECT_OK(y_global_or.status()); std::unique_ptr x_global = @@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) { auto result_or = client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); - std::unique_ptr result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42\n)", result->ToString()); + xla::Literal result = std::move(result_or.ValueOrDie()); + EXPECT_EQ("(s32[]) (\n42\n)", result.ToString()); config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 211caf8736990db064c8aac817ebe0897b291f69..d6f42bac86f1ef359531d67b652d43d851d7ac02 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -25,9 +25,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" @@ -75,6 +78,8 @@ Status CheckFeedFetchNameConflicts(const string& kind, } // namespace +const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation"; + Status ValidateConfig(const tf2xla::Config& config) { std::set names; for (const tf2xla::Feed& feed : config.feed()) { @@ -323,4 +328,101 @@ uint32 GetXLARandomSeed() { return counter.fetch_add(2); } +// TODO(b/77601805): add tests for associated function related stuff. +bool HasAssociatedFunction(const NodeDef& node_def, + FunctionLibraryRuntime* flr) { + if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) { + return true; + } + + if (node_def.op() == FunctionLibraryDefinition::kGradientOp) { + // Skip gradient op. Gradient op has "f" attr, which is set to the function + // we are getting gradient for. That function is not associated with the op. + return false; + } + + for (const auto& iter : node_def.attr()) { + if (iter.second.has_func()) { + return true; + } + } + + return false; +} + +std::vector GetAssociatedFunctions( + const Node& node, FunctionLibraryRuntime* flr) { + std::vector results; + const string& op = node.type_string(); + if (flr->GetFunctionLibraryDefinition()->Contains(op)) { + // This is a function call node. + AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); + results.emplace_back(AssociatedFunctionInfo(op, attrs)); + } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) { + // Skip gradient op. Gradient op has "f" attr, which is set to the function + // we are getting gradient for. That function is not associated with the op. + } else { + // Collect all function attrs for the node. + for (auto& iter : node.attrs()) { + if (iter.second.has_func()) { + VLOG(2) << "Found function attr for node " << node.name() << ": " + << iter.first << " = " << iter.second.func().name(); + results.emplace_back(AssociatedFunctionInfo( + iter.second.func().name(), iter.second.func().attr(), iter.first)); + } + } + } + return results; +} + +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name) { + switch (associated_function.type()) { + case AssociatedFunctionInfo::kFunctionCallNode: { + // Change this node to call the new function. + NodeDefBuilder builder(node->name(), rewritten_function_name, fld); + for (auto attr : node->attrs()) { + builder.Attr(attr.first, attr.second); + } + for (int i = 0; i < node->num_inputs(); i++) { + Node* input_node; + TF_RETURN_IF_ERROR(node->input_node(i, &input_node)); + builder.Input(input_node->name(), i, node->input_type(i)); + } + builder.Device(node->assigned_device_name().empty() + ? node->requested_device() + : node->assigned_device_name()); + NodeDef node_def; + TF_RETURN_IF_ERROR(builder.Finalize(&node_def)); + Status s; + Node* new_node = graph->AddNode(node_def, &s); + TF_RETURN_IF_ERROR(s); + for (auto edge : node->in_edges()) { + graph->AddEdge(edge->src(), edge->src_output(), new_node, + edge->dst_input()); + } + for (auto edge : node->out_edges()) { + graph->AddEdge(new_node, edge->src_output(), edge->dst(), + edge->dst_input()); + } + graph->RemoveNode(node); + break; + } + case AssociatedFunctionInfo::kFunctionAttr: { + // Change function attr to rewritten functions. + NameAttrList func; + TF_RETURN_IF_ERROR( + GetNodeAttr(node->attrs(), associated_function.attr_name(), &func)); + node->ClearAttr(associated_function.attr_name()); + func.set_name(rewritten_function_name); + node->AddAttr(associated_function.attr_name(), func); + break; + } + } + + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index dcddef841825f90551958954a525500a07ddeb86..6065d0bb9a3abd23b8911c5049914be8a5f23b99 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -60,6 +60,67 @@ void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, // Returns the next random seed to use for seeding xla rng. uint32 GetXLARandomSeed(); +// Indicates how a FunctionDef is associated with a graph node (e.g. the node is +// a function call, or the node has function attrs). +class AssociatedFunctionInfo { + public: + enum AssociatedFunctionType { + kFunctionCallNode = 0, + kFunctionAttr = 1, + }; + + // The node is a function call. + AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs) + : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {} + + // The function is an attr of the node. + AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs, + const string& attr_name) + : type_(kFunctionAttr), + func_name_(func_name), + attrs_(attrs), + attr_name_(attr_name) {} + + AssociatedFunctionType type() const { return type_; } + + const string& func_name() const { return func_name_; } + + const string& attr_name() const { return attr_name_; } + + const AttrValueMap& attrs() const { return attrs_; } + + private: + // Available for all instances. + AssociatedFunctionType type_; + string func_name_; + AttrValueMap attrs_; + + // Only available if the function is defined in an attr. + string attr_name_; +}; + +// Returns if the NodeDef has associated function. +bool HasAssociatedFunction(const NodeDef& node_def, + FunctionLibraryRuntime* flr); + +// Gets functions associated with the node. Current cases: +// 1. For function call node, its function name; +// 2. For nodes like XlaWhile/XlaIf, all their function attributes. +std::vector GetAssociatedFunctions( + const Node& node, FunctionLibraryRuntime* flr); + +// Changes associated functions for the node. Current cases: +// 1. For function call node, creates a new node with the new function name and +// remove the old node; +// 2. For nodes like XlaWhile/XlaIf, modify their function attributes. +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name); + +// Attribute to mark nodes to be executed on host. +extern const char kXlaOutsideCompilationAttrName[]; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index c969212a1bfaa6cab0d896ee074cfd4e2b283ae4..d00b1376620c0c9d112c7d7426758f6d3f25e86f 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -26,21 +26,26 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { *type = xla::PRED; return Status::OK(); case tensorflow::DT_INT8: + case tensorflow::DT_QINT8: *type = xla::S8; return Status::OK(); case tensorflow::DT_INT16: + case tensorflow::DT_QINT16: *type = xla::S16; return Status::OK(); case tensorflow::DT_INT32: + case tensorflow::DT_QINT32: *type = xla::S32; return Status::OK(); case tensorflow::DT_INT64: *type = xla::S64; return Status::OK(); case tensorflow::DT_UINT8: + case tensorflow::DT_QUINT8: *type = xla::U8; return Status::OK(); case tensorflow::DT_UINT16: + case tensorflow::DT_QUINT16: *type = xla::U16; return Status::OK(); case tensorflow::DT_UINT32: @@ -64,12 +69,6 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_COMPLEX64: *type = xla::C64; return Status::OK(); - case tensorflow::DT_QUINT8: - *type = xla::U8; - return Status::OK(); - case tensorflow::DT_QINT32: - *type = xla::S32; - return Status::OK(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType ", diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index dcb455779dcc9c044303210bc81925831ae50d5e..739e47778a796815348058894b0097c3a312dbd8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" @@ -150,6 +149,9 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, TF_RETURN_WITH_CONTEXT_IF_ERROR( GetFunctionBody(function, flib_runtime_, fbody), "Local lookup failed with: ", status.error_message()); + VLOG(4) << "Function " << function.name() << " in flib_runtime_"; + } else { + VLOG(4) << "Function " << function.name() << " in local_flib_runtime_"; } return Status::OK(); } @@ -323,8 +325,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, step_container->name(), XlaContext::kXlaContextResourceName, xla_context)); - GraphCompiler graph_compiler(xla_context, device, graph.get(), flib, - step_container.get()); + GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); TF_RETURN_IF_ERROR(graph_compiler.Compile()); // Explicitly clean up the step container, to capture the cleanup status. step_container.reset(); @@ -743,18 +744,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << dump_graph::DumpGraphToFile( - absl::StrCat("xla_compile_graph_", name), *graph); + absl::StrCat("xla_compile_graph_", name), *graph, + flib_runtime_->GetFunctionLibraryDefinition()); } // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); - // Converts Tensorflow's graph control-flow constructs into functional - // control-flow that can be compiled into XLA code. - TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), - graph.get(), local_flib_def_.get())); - // Detect invalid nodes. // FunctionalizeControlFlow may remove some nodes from the graph. TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 40ce9fb41c3bedfcff6c9d03299b67ca4b0a8407..72b17d04fc42eb00781e96b412465b73fb29a5c2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -208,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) { std::move(graph), args, &result)); // Tests that the generated computation works. - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({4, 143}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation of a graph where the _Retval node is not necessarily last @@ -264,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { args, &result)); // Tests that the generated computation works. - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); } // Tests that the compiler doesn't reorder the parameters. @@ -408,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({-7, -42}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE( - xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } { @@ -443,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR0(7); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({-7, -42}); - std::unique_ptr expected = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR0(7); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal expected = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal)); } } @@ -619,10 +604,17 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { auto instr1 = c1.instructions(j); auto instr2 = c2.instructions(j); instr1.clear_name(); + instr1.clear_id(); + instr1.clear_operand_ids(); instr2.clear_name(); - // The names of instructions were uniquified by the XlaBuilder, the rest - // of the fields should be identical. + instr2.clear_id(); + instr2.clear_operand_ids(); + // The names of instructions were uniquified by the XlaBuilder and the + // unique ids may be different, the rest of the fields should be + // identical. string str1, str2; + LOG(INFO) << "instr1 = " << instr1.DebugString(); + LOG(INFO) << "instr2 = " << instr2.DebugString(); instr1.AppendPartialToString(&str1); instr2.AppendPartialToString(&str2); EXPECT_EQ(str1, str2); @@ -672,34 +664,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { update.tensor_array_gradients_accessed); // Tests that the generated computation works. - std::unique_ptr input_base = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr input_grad2 = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr input = - xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); + xla::Literal input_base = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal input_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2}); std::unique_ptr param0_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); + client_->TransferToServer(input).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr output_read = - xla::LiteralUtil::CreateR0(42); - std::unique_ptr output_base = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr output_grad1 = - xla::LiteralUtil::CreateR1({0, 1}); - std::unique_ptr output_grad2 = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr output_resource = xla::LiteralUtil::MakeTuple( - {output_base.get(), output_grad1.get(), output_grad2.get()}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal output_read = xla::LiteralUtil::CreateR0(42); + xla::Literal output_base = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal output_grad1 = xla::LiteralUtil::CreateR1({0, 1}); + xla::Literal output_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal output_resource = + xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&output_read, &output_resource}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -866,29 +850,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { void RunAndCheckVariablesComputation( xla::Client* client, const XlaCompiler::CompilationResult& result) { - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({5, 144}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({4, 143}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({5, 144}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a simple graph that reads and writes a variable. @@ -952,20 +931,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { std::move(graph), args, &result)); // Tests that the generated computation works. - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ReturnResourceHandle) { @@ -1069,29 +1045,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. - std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR2({{4, 55}, {1, -3}}); - std::unique_ptr param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = + xla::Literal expected0 = xla::LiteralUtil::CreateR2({{27, 67}, {35, 402}}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({26, 66, 34, 401}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { @@ -1138,29 +1112,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. - std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({4, 55, 1, -3}); - std::unique_ptr param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({27, 67, 35, 402}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({26, 66, 34, 401}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({27, 67, 35, 402}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a graph which has a function with an invalid op. @@ -1255,25 +1226,8 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; - status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", - std::move(graph_copy), args, &result); - ASSERT_FALSE(status.ok()); - EXPECT_TRUE( - absl::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: {{node NoOp}}")) - << status.error_message(); - } - - // Fix control edges for NoOp. - { - std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); - CopyGraph(*graph, graph_copy.get()); - EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); - XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", std::move(graph_copy), args, &result)); - EXPECT_EQ(0, result.resource_updates.size()); } } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 636cb71e21b7bd675ddd4c41f5352094088b85e3..2a9eaeee146bf6d792e010df7e041f9986b2c77e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -83,6 +83,10 @@ DataType XlaOpKernelContext::input_type(int index) const { return context_->input(index).dtype(); } +DataType XlaOpKernelContext::InputType(absl::string_view name) { + return GetInputTensorByName(name).dtype(); +} + xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { xla::PrimitiveType type; Status status = DataTypeToPrimitiveType(input_type(index), &type); @@ -102,8 +106,7 @@ Status XlaOpKernelContext::ConstantInput(int index, static xla::StatusOr InputIndex(XlaOpKernelContext* context, absl::string_view name) { int start, stop; - TF_RETURN_IF_ERROR(context->op_kernel().InputRange( - StringPiece(name.data(), name.length()), &start, &stop)); + TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued input name '", name, @@ -214,16 +217,15 @@ Status XlaOpKernelContext::ConstantInputReshaped( context_->op_kernel().name(), " input ", index, ".\nError: ", constant_graph.status().error_message()); } - xla::StatusOr> computed = - compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(), - &layout); + xla::StatusOr computed = compiler()->client()->ComputeConstant( + constant_graph.ValueOrDie(), &layout); if (!computed.ok()) { return errors::Internal("Error evaluating ", context_->op_kernel().name(), " input ", index, - "as a compile-time constant.\nError: ", + " as a compile-time constant.\nError: ", computed.status().error_message()); } - *constant_literal = std::move(*computed.ValueOrDie()); + *constant_literal = std::move(computed).ValueOrDie(); return Status::OK(); } @@ -366,8 +368,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, std::vector* handles, std::vector* shapes) { OpInputList inputs; - TF_RETURN_IF_ERROR( - context_->input_list(StringPiece(name.data(), name.size()), &inputs)); + TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { @@ -380,8 +381,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, Status XlaOpKernelContext::ConstantInputList( absl::string_view name, std::vector* outputs) { int start, stop; - TF_RETURN_IF_ERROR(op_kernel().InputRange( - StringPiece(name.data(), name.size()), &start, &stop)); + TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); outputs->resize(stop - start); for (int i = start; i < stop; ++i) { TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i])); @@ -615,7 +615,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { const Tensor* tensor; - CHECK(context_->input(StringPiece(name.data(), name.length()), &tensor).ok()); + CHECK(context_->input(name, &tensor).ok()); return *tensor; } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 962c86d3a568322b6d7134508b3f5911f2d9b9a5..a3a0d10cc06cd4afceec728b7dbe287389099b9d 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -71,6 +71,9 @@ class XlaOpKernelContext { // Returns the type of input `index`. DataType input_type(int index) const; + // Returns the type of input `name`. + DataType InputType(absl::string_view name); + // Returns the type of input `index` as an xla::PrimitiveType. If the type // is not representable as an XLA type, sets an error status and returns // xla::PRIMITIVE_TYPE_INVALID. @@ -79,7 +82,7 @@ class XlaOpKernelContext { // Returns the shape of input `index`. TensorShape InputShape(int index); - // Returns the shape of input `name`. + // Returns the shape of input with name `name`. TensorShape InputShape(absl::string_view name); // Returns input `index` as a XlaOp. Unlike diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 5d53169f6898f02184f6a0facfcca6b7ad1b6738..74a4885f1f029628817f6ec3a36fcb98719d6a41 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -22,7 +22,6 @@ limitations under the License. #include #include -#include "absl/strings/string_view.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 76e36f3c46b22742b6cf0c86e89d17899338a60f..ef70c1f8ac7e31b194dfec2ae67be6763cb77753 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -193,6 +193,7 @@ cc_library( ":types", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/synchronization", ], ) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 8818f813127230d3b39d4b48d874b7cfb24b8abc..5dde5b432f136c16d4e3795569499ee5de709763 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {} Client::~Client() = default; -StatusOr> Client::Transfer( - const GlobalData& data, const Shape* shape_with_layout) { +StatusOr Client::Transfer(const GlobalData& data, + const Shape* shape_with_layout) { TransferToClientRequest request; *request.mutable_data() = data.handle(); if (shape_with_layout != nullptr) { @@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, return Status::OK(); } -StatusOr> Client::TransferFromOutfeed( +StatusOr Client::TransferFromOutfeed( const Shape* shape_with_layout, int64 replica_id, const DeviceHandle* device_handle) { TransferFromOutfeedRequest request; @@ -162,7 +162,7 @@ Status Client::ResetDevice() { return Status::OK(); } -StatusOr> Client::ExecuteAndTransfer( +StatusOr Client::ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { @@ -177,8 +177,8 @@ StatusOr> Client::ExecuteAndTransfer( return Transfer(*data, shape_with_output_layout); } -StatusOr> Client::ComputeConstant( - const XlaComputation& computation, const Layout* output_layout) const { +StatusOr Client::ComputeConstant(const XlaComputation& computation, + const Layout* output_layout) const { ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 7960b078686e611a6439af495d266f9084992d29..6f4d33c469f1f885cfeef546e3981dc3417ef71f 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -96,8 +96,8 @@ class Client { // // If shape_with_layout is not nullptr, it points to a shape whose layout will // be the layout of the returned literal. - StatusOr> Transfer( - const GlobalData& data, const Shape* shape_with_layout = nullptr); + StatusOr Transfer(const GlobalData& data, + const Shape* shape_with_layout = nullptr); // Transfer the given literal to the server. This allocates memory on the // device and copies the literal's contents over. Returns a global data handle @@ -122,7 +122,7 @@ class Client { // device_handle and replica_id together specify a particular device; a device // assigned for the given replica_id among the replicas that the given device // handle belongs to. - StatusOr> TransferFromOutfeed( + StatusOr TransferFromOutfeed( const Shape* shape_with_layout, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); @@ -132,7 +132,7 @@ class Client { // Executes the computation with the given arguments and transfers the result // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options = nullptr, @@ -153,7 +153,7 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. - StatusOr> ComputeConstant( + StatusOr ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 6861521acc0db1d640666a6793b898a183ab6a17..25cc37edc43c28a636797c310c8882eea09a0ef3 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -76,7 +76,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client) { if (DataSizeOfShape(shape) < (1LL << 20)) { - StatusOr> literal_status = MakeFakeLiteral(shape); + StatusOr literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via // an on-device computation. @@ -84,7 +84,7 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, tensorflow::error::UNIMPLEMENTED); return MakeFakeDataViaDeviceOrDie(shape, client); } - return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie(); + return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie(); } // If the data is large, generate it on-device. diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 4402ba8762c1538951c326c880fc3b6dd63ef0c6..f96b6c9c261a9686fb647e3da0dcc933cd1f70df 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -195,9 +195,8 @@ Status LocalExecutable::RecordArguments( HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - LiteralFromShapedBuffer(*argument)); - *hlo_snapshot->add_arguments() = literal->ToProto(); + TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument)); + *hlo_snapshot->add_arguments() = literal.ToProto(); } return Status::OK(); } @@ -205,13 +204,12 @@ Status LocalExecutable::RecordArguments( Status LocalExecutable::RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_result(); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - LiteralFromShapedBuffer(*result)); - *hlo_snapshot->mutable_result() = literal->ToProto(); + TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result)); + *hlo_snapshot->mutable_result() = literal.ToProto(); return Status::OK(); } -StatusOr> LocalExecutable::LiteralFromShapedBuffer( +StatusOr LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN(auto stream, backend_->BorrowStream(shaped_buffer.device_ordinal())); @@ -277,7 +275,7 @@ StatusOr LocalClient::LiteralToShapedBuffer( return std::move(scoped_buffer); } -StatusOr> LocalClient::ShapedBufferToLiteral( +StatusOr LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream( shaped_buffer.device_ordinal())); @@ -298,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal, literal); } -StatusOr> LocalClient::TransferFromOutfeedLocal( - const Shape& shape, int device_ordinal) { +StatusOr LocalClient::TransferFromOutfeedLocal(const Shape& shape, + int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); auto literal = Literal::CreateFromShape(shape); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, shape, literal.get())); + executor, shape, &literal)); return std::move(literal); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 56c3a3da023ebf124b4bd91c2c608d0cd00a2381..feb2f8ec9dab5bf13afdc866d10ccbe74f8edcb9 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -84,8 +84,7 @@ class LocalExecutable { Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); // Returns a literal containing the contents of the given ShapedBuffer. - StatusOr> LiteralFromShapedBuffer( - const ShapedBuffer& shaped_buffer); + StatusOr LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer); // The ordinal of the device which this executable was compiled for. The // executable can run on all equivalent devices (as determined by @@ -132,8 +131,7 @@ class LocalClient : public Client { // Copy the data from the device contained in the given ShapedBuffer and // return as a Literal. - StatusOr> ShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer); + StatusOr ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid // as long as the handle is valid. @@ -151,8 +149,8 @@ class LocalClient : public Client { // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with // Client::TransferFromOutfeed. - StatusOr> TransferFromOutfeedLocal( - const Shape& shape, int device_ordinal); + StatusOr TransferFromOutfeedLocal(const Shape& shape, + int device_ordinal); // Returns the device ordinal that corresponds to the given replica number. // diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 887b970661cc95d40da76dee1c80c91ac499c9b2..95ff6432a591f87845729b180397e33a85e5e9a5 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -134,11 +134,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn( StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { TF_RETURN_IF_ERROR(first_error_); - TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size())); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto, + LookUpInstructionByHandle(root_id)); ProgramShape program_shape; - *program_shape.mutable_result() = instructions_[root_id].shape(); + *program_shape.mutable_result() = root_proto->shape(); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. @@ -181,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, return; } - CHECK(op_handle < instructions_.size() && op_handle >= 0); - - const HloInstructionProto& instr = instructions_[op_handle]; + const HloInstructionProto& instr = + *(LookUpInstructionByHandle(op_handle).ValueOrDie()); const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie(); switch (opcode) { default: @@ -283,6 +283,7 @@ StatusOr XlaBuilder::Build(int64 root_id) { // Clear data held by this builder. this->instructions_.clear(); + this->handle_to_index_.clear(); this->embedded_.clear(); this->parameter_numbers_.clear(); @@ -738,7 +739,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); - *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto(); + *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); } @@ -2285,7 +2286,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( *program_shape->mutable_result() = root->shape(); // We use std::set to keep the instruction ids in ascending order (which is - // also a valid denpendency order). The related ops will be added to the + // also a valid dependency order). The related ops will be added to the // subgraph in the same order. std::set related_ops; tensorflow::gtl::FlatSet related_calls; // Related computations. @@ -2293,14 +2294,16 @@ StatusOr XlaBuilder::BuildConstantSubGraph( worklist.push(root->id()); related_ops.insert(root->id()); while (!worklist.empty()) { - int64 node = worklist.front(); + int64 handle = worklist.front(); worklist.pop(); - for (int64 id : instructions_[node].operand_ids()) { + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, + LookUpInstructionByHandle(handle)); + for (int64 id : instr_proto->operand_ids()) { if (related_ops.insert(id).second) { worklist.push(id); } } - for (int64 called_id : instructions_[node].called_computation_ids()) { + for (int64 called_id : instr_proto->called_computation_ids()) { related_calls.insert(called_id); } } @@ -2308,7 +2311,9 @@ StatusOr XlaBuilder::BuildConstantSubGraph( // Add related ops to the computation. for (int64 id : related_ops) { auto* instr = entry.add_instructions(); - *instr = instructions_[id]; + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src, + LookUpInstructionByHandle(id)); + *instr = *instr_src; // Ensures that the instruction names are unique among the graph. const string& new_name = StrCat(instr->name(), ".", entry.id(), ".", instr->id()); @@ -2415,11 +2420,11 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, absl::Span operands) { TF_RETURN_IF_ERROR(first_error_); - const int64 handle = instructions_.size(); + const int64 handle = GetUniqueId(); instr.set_id(handle); instr.set_opcode(HloOpcodeString(opcode)); if (instr.name().empty()) { - instr.set_name(StrCat(instr.opcode())); + instr.set_name(instr.opcode()); } for (const auto& operand : operands) { if (operand.builder_ == nullptr) { @@ -2437,7 +2442,8 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, *instr.mutable_sharding() = *sharding_; } - instructions_.push_back(instr); + handle_to_index_[handle] = instructions_.size(); + instructions_.push_back(std::move(instr)); XlaOp op(handle, this); return op; @@ -2467,10 +2473,16 @@ StatusOr XlaBuilder::LookUpInstruction( op.handle(), op.builder_->name(), this->name()); } - if (op.handle() >= instructions_.size() || op.handle() < 0) { - return InvalidArgument("no XlaOp value %d", op.handle()); + return LookUpInstructionByHandle(op.handle()); +} + +StatusOr XlaBuilder::LookUpInstructionByHandle( + int64 handle) const { + auto it = handle_to_index_.find(handle); + if (it == handle_to_index_.end()) { + return InvalidArgument("No XlaOp with handle %d", handle); } - return &instructions_[op.handle()]; + return &instructions_[it->second]; } // Enqueues a "retrieve parameter value" instruction for a parameter that was diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 58e8f4e7fa3cb7dc1582f1a4f6002d8324eb3ffb..d0c59fa6f27bc265c0868734ed95a196002fbd2e 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stacktrace.h" @@ -955,6 +956,8 @@ class XlaBuilder { HloInstructionProto* instr); StatusOr LookUpInstruction(const XlaOp& op) const; + StatusOr LookUpInstructionByHandle( + int64 handle) const; // Internal helper method that does the building for an arbitrary unary op. XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); @@ -1024,6 +1027,10 @@ class XlaBuilder { // The instructions of this computation. std::vector instructions_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the + // instruction is held. + tensorflow::gtl::FlatMap handle_to_index_; + // The embedded computations used by this computation. Each computation was // the entry computation of some XlaComputation, the key is the unique id of // that XlaComputation. @@ -2112,12 +2119,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, template XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*LiteralUtil::CreateR0(value)); + return ConstantLiteral(LiteralUtil::CreateR0(value)); } template XlaOp XlaBuilder::ConstantR1(absl::Span values) { - return ConstantLiteral(*LiteralUtil::CreateR1(values)); + return ConstantLiteral(LiteralUtil::CreateR1(values)); } template @@ -2129,44 +2136,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { } inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*LiteralUtil::CreateR1(values)); + return ConstantLiteral(LiteralUtil::CreateR1(values)); } template XlaOp XlaBuilder::ConstantR2( std::initializer_list> values) { - return ConstantLiteral(*LiteralUtil::CreateR2(values)); + return ConstantLiteral(LiteralUtil::CreateR2(values)); } template XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantFromArray(const Array& values) { - return ConstantLiteral(*LiteralUtil::CreateFromArray(values)); + return ConstantLiteral(LiteralUtil::CreateFromArray(values)); } template XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { - return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D(values)); + return ConstantLiteral(LiteralUtil::CreateR2FromArray2D(values)); } template XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template @@ -2189,12 +2196,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { template XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { - return ConstantLiteral(builder, *LiteralUtil::CreateR0(value)); + return ConstantLiteral(builder, LiteralUtil::CreateR0(value)); } template XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); } template @@ -2207,13 +2214,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { inline XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); } template XlaOp ConstantR2(XlaBuilder* builder, std::initializer_list> values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR2(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR2(values)); } template @@ -2221,14 +2228,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, const Array& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateFromArray(values)); + LiteralUtil::CreateFromArray(values)); } template @@ -2236,15 +2242,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, const Array2D& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantR2FromArray2D(XlaBuilder* builder, const Array2D& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateR2FromArray2D(values)); + LiteralUtil::CreateR2FromArray2D(values)); } template @@ -2253,7 +2258,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, const Layout& layout) { return ConstantLiteral( builder, - *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 3f7635bd400c6ec87e0e3a739658272e906a72fb..5035f4198890857fcafd0156d7eaeeb4bc164322 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) { return *this; } -std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = absl::make_unique(shape); - literal->root_piece_->ForEachMutableSubpiece( +Literal LiteralBase::CreateFromShape(const Shape& shape) { + Literal literal(shape); + literal.root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { if (ShapeUtil::IsArray(piece->subshape())) { memset(piece->untyped_data(), 0, piece->size_bytes()); @@ -278,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } -/* static */ StatusOr> -MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { +/* static */ StatusOr MutableLiteralBase::CreateFromProto( + const LiteralProto& proto) { if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } @@ -287,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { return InvalidArgument("LiteralProto has no layout"); } - auto literal = absl::make_unique(proto.shape()); + Literal literal(proto.shape()); - TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { const LiteralProto* proto_element = &proto; for (int64 i : index) { @@ -556,38 +556,37 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { } } -std::unique_ptr LiteralBase::Relayout( - const Layout& new_layout, const ShapeIndex& shape_index) const { +Literal LiteralBase::Relayout(const Layout& new_layout, + const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. Shape new_shape = shape(); Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); *subshape->mutable_layout() = new_layout; - auto result = absl::make_unique(new_shape); - TF_CHECK_OK(result->CopyFrom(*this)); + Literal result(new_shape); + TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr LiteralBase::Relayout( - const Shape& shape_with_layout) const { +Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) << " not compatible with literal shape " << ShapeUtil::HumanString(shape()); - std::unique_ptr result = CreateFromShape(shape_with_layout); + Literal result = CreateFromShape(shape_with_layout); ShapeUtil::ForEachSubshape( - result->shape(), + result.shape(), [this, &result](const Shape& subshape, const ShapeIndex& index) { if (ShapeUtil::IsArray(subshape)) { - TF_CHECK_OK(result->CopyFrom(*this, - /*dest_shape_index=*/index, - /*src_shape_index=*/index)); + TF_CHECK_OK(result.CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); } }); return result; } -StatusOr> LiteralBase::Broadcast( +StatusOr LiteralBase::Broadcast( const Shape& result_shape, absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Broadcast only supports arrays."); @@ -598,14 +597,14 @@ StatusOr> LiteralBase::Broadcast( result_shape.dimensions(dimensions[i])); } - std::unique_ptr result = absl::make_unique(result_shape); + Literal result(result_shape); // scratch_source_index is temporary storage space for the computed index into // the input literal. We put it here to avoid allocating an std::vector in // every iteration of ShapeUtil::ForEachIndex. std::vector scratch_source_index(shape().dimensions_size()); - char* dest_data = static_cast(result->untyped_data()); + char* dest_data = static_cast(result.untyped_data()); const char* source_data = static_cast(untyped_data()); const int64 primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); @@ -627,37 +626,36 @@ StatusOr> LiteralBase::Broadcast( return std::move(result); } -StatusOr> LiteralBase::Reshape( +StatusOr LiteralBase::Reshape( absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); } - std::unique_ptr output; + Literal output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { output = Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); } else { - output = CloneToUnique(); + output = Clone(); } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. - *output->mutable_shape_do_not_use() = + *output.mutable_shape_do_not_use() = ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); - int64 elements_after = ShapeUtil::ElementsIn(output->shape()); + int64 elements_after = ShapeUtil::ElementsIn(output.shape()); if (elements_before != elements_after) { return InvalidArgument( "Shapes before and after Literal::Reshape have different numbers " "of elements: %s vs %s.", ShapeUtil::HumanString(shape()), - ShapeUtil::HumanString(output->shape())); + ShapeUtil::HumanString(output.shape())); } return std::move(output); } -std::unique_ptr LiteralBase::Transpose( - absl::Span permutation) const { +Literal LiteralBase::Transpose(absl::Span permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; @@ -687,32 +685,31 @@ std::unique_ptr LiteralBase::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - auto new_literal = absl::make_unique(permuted_shape); - DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), + Literal new_literal(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); + std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes()); return new_literal; } template -std::unique_ptr LiteralBase::SliceInternal( +Literal LiteralBase::SliceInternal( const Shape& result_shape, absl::Span start_indices) const { - auto result_literal = absl::make_unique(result_shape); + Literal result_literal(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); - result_literal->EachCell( + result_literal.EachCell( [&](absl::Span indices, NativeT /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } NativeT value = Get(new_indices); - result_literal->Set(indices, value); + result_literal.Set(indices, value); }); return result_literal; } -std::unique_ptr LiteralBase::Slice( - absl::Span start_indices, - absl::Span limit_indices) const { +Literal LiteralBase::Slice(absl::Span start_indices, + absl::Span limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; DimensionVector result_dimensions; @@ -750,12 +747,6 @@ Literal LiteralBase::Clone() const { return result; } -std::unique_ptr LiteralBase::CloneToUnique() const { - auto result = absl::make_unique(shape()); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - string LiteralBase::GetAsString(absl::Span multi_index, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); @@ -1191,14 +1182,14 @@ void LiteralBase::EachCellAsString( namespace { template -std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const LiteralBase& src_literal, const ConverterType& converter) { +Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, + const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = absl::make_unique(ShapeUtil::ChangeElementType( + Literal result_literal(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType())); auto src_data = src_literal.data(); - auto dest_data = result_literal->template data(); + auto dest_data = result_literal.template data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { @@ -1208,8 +1199,7 @@ std::unique_ptr ConvertBetweenNativeTypesWithConverter( } template -std::unique_ptr ConvertBetweenNativeTypes( - const LiteralBase& src_literal) { +Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1217,7 +1207,7 @@ std::unique_ptr ConvertBetweenNativeTypes( template typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), - std::unique_ptr>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast(src); @@ -1232,20 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { // identical sizes higher up. template typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), - std::unique_ptr>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template -std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { +Literal ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = absl::make_unique( + Literal result_literal( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative::type; absl::Span src_data = src_literal.data(); - absl::Span dest_data = result_literal->data(); + absl::Span dest_data = result_literal.data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = complex64(static_cast(src_data[i]), 0); @@ -1254,8 +1244,7 @@ std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, - bool bitcast) { +Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { return BitcastBetweenNativeTypes< @@ -1273,9 +1262,9 @@ std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, } template -StatusOr> ConvertIfDestTypeMatches( - const LiteralBase& src_literal, PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, + PrimitiveType primitive_dest_type, + bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ case (type): \ @@ -1307,12 +1296,12 @@ StatusOr> ConvertIfDestTypeMatches( PrimitiveType_Name(primitive_dest_type)); } -StatusOr> ConvertSwitch( - const LiteralBase& literal, PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr ConvertSwitch(const LiteralBase& literal, + PrimitiveType primitive_dest_type, + bool bitcast) { TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { - return literal.CloneToUnique(); + return literal.Clone(); } switch (literal.shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ @@ -1342,12 +1331,12 @@ StatusOr> ConvertSwitch( } // namespace -StatusOr> LiteralBase::Convert( +StatusOr LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr> LiteralBase::BitcastConvert( +StatusOr LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { @@ -1362,17 +1351,8 @@ StatusOr> LiteralBase::BitcastConvert( return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr> LiteralBase::ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16) const { +StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { if (!ShapeUtil::IsTuple(dest_shape)) { - if (round_f32_to_bf16 && shape().element_type() == F32 && - dest_shape.element_type() == BF16) { - auto converter = [](float src) { - return tensorflow::bfloat16::round_to_bfloat16(src); - }; - return ConvertBetweenNativeTypesWithConverter(*this, - converter); - } return Convert(dest_shape.element_type()); } std::vector elements; @@ -1381,11 +1361,9 @@ StatusOr> LiteralBase::ConvertToShape( TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); - elements.push_back(std::move(*new_element)); + elements.push_back(std::move(new_element)); } - auto converted = absl::make_unique(); - *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements)); - return std::move(converted); + return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements)); } /* static */ Literal MutableLiteralBase::MoveIntoTuple( @@ -1782,6 +1760,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case PRED: CopyToRepeatedField(proto->mutable_preds(), data()); break; + case S8: + proto->set_s8s(static_cast(data().data()), + element_count()); + break; case U8: proto->set_u8s(static_cast(data().data()), element_count()); @@ -1872,6 +1854,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case PRED: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); break; + case S8: { + auto s8_data = data(); + TF_RET_CHECK(proto.s8s().size() == s8_data.size()); + std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin()); + } break; case U8: { auto u8_data = data(); TF_RET_CHECK(proto.u8s().size() == u8_data.size()); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index b928cb637494dec220a0912fdea96ed25cde13ef..1e0a2ad0ddf81d6813942c77ae273e2ce24e735e 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -217,31 +217,20 @@ class LiteralBase { // Converts this literal to the given shape. Returns an error is the // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. - StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + StatusOr ConvertToShape(const Shape& dest_shape) const; // Converts this literal to another primitive type using a bitcast // conversion. The to and from primitive types must have the same bit // width. Returns an error if the conversion is not possible. This literal // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; + StatusOr BitcastConvert(PrimitiveType primitive_dest_type) const; // Converts this literal to another primitive type. Returns an error if the // conversion is not possible. This literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; + StatusOr Convert(PrimitiveType primitive_dest_type) const; - // Clones the underlying buffers into a new Literal, or new - // std::unique_ptr. + // Clones the underlying buffers into a new Literal. Literal Clone() const; - std::unique_ptr CloneToUnique() const; // TODO(b/67651157): The methods below which perform computation on Literals // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with @@ -259,24 +248,23 @@ class LiteralBase { // Note: this is useful when the client wants to ensure that a value placed in // the XLA allocation tracker has a particular layout; for efficiency // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; + Literal Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; // An overload of Relayout which changes the layout of the entire shape rather // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; + Literal Relayout(const Shape& shape_with_layout) const; // Creates a new literal by reshaping this literal to have the given // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. // This literal must be an array. - StatusOr> Reshape( - absl::Span dimensions) const; + StatusOr Reshape(absl::Span dimensions) const; // Creates a new literal by broadcasting this literal with `dimensions` to // yield a literal of shape `result_shape`. - StatusOr> Broadcast( - const Shape& result_shape, absl::Span dimensions) const; + StatusOr Broadcast(const Shape& result_shape, + absl::Span dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -285,7 +273,7 @@ class LiteralBase { // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. // This literal must be an array. - std::unique_ptr Transpose(absl::Span permutation) const; + Literal Transpose(absl::Span permutation) const; // Creates a sub-array from this literal by extracting the indices // [start_index, limit_index) of each dimension. The result literal has the @@ -293,15 +281,15 @@ class LiteralBase { // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. // This literal must be an array. - std::unique_ptr Slice(absl::Span start_indices, - absl::Span limit_indices) const; + Literal Slice(absl::Span start_indices, + absl::Span limit_indices) const; // Creates a literal with a prepended dimension with bound "times"; e.g. a // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this // literal replicated four times. // This literal must be an array. template - std::unique_ptr Replicate(int64 times) const; + Literal Replicate(int64 times) const; // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive @@ -312,7 +300,7 @@ class LiteralBase { // initialization, then reinitialization. Conside if a call to // absl::make_unique(shape), followed by the call to // MutableLiteralBase::Populate can be used instead. - static std::unique_ptr CreateFromShape(const Shape& shape); + static Literal CreateFromShape(const Shape& shape); protected: // A data structure representing a subshape at a particular ShapeIndex within @@ -539,8 +527,8 @@ class LiteralBase { private: template - std::unique_ptr SliceInternal( - const Shape& result_shape, absl::Span start_indices) const; + Literal SliceInternal(const Shape& result_shape, + absl::Span start_indices) const; }; // Abstract base class representing a mutable literal in XLA. @@ -687,8 +675,7 @@ class MutableLiteralBase : public LiteralBase { static Literal MoveIntoTuple(absl::Span elements); // Serialize from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); + static StatusOr CreateFromProto(const LiteralProto& proto); protected: // Returns the piece at the given ShapeIndex. @@ -1137,15 +1124,14 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) { } template -std::unique_ptr LiteralBase::Replicate(int64 times) const { +Literal LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } - auto literal = absl::make_unique( - ShapeUtil::MakeShape(shape().element_type(), bounds)); - int64 elements = ShapeUtil::ElementsIn(literal->shape()); + Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds)); + int64 elements = ShapeUtil::ElementsIn(literal.shape()); if (elements == 0) { return literal; } @@ -1157,7 +1143,7 @@ std::unique_ptr LiteralBase::Replicate(int64 times) const { bool done = false; while (!done) { const auto element = Get(input_indices); - literal->Set(output_indices, element); + literal.Set(output_indices, element); done = true; for (int n = 0; n < output_indices.size(); ++n) { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 1a64594db86af31dcc196725d4b4f2a3ad9e4746..7ad287c8973367fb04583e6911ff75e76bdf5f1e 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -92,48 +92,48 @@ class LiteralUtilTest : public ::testing::Test { Layout layout_r3_dim0minor_; Layout layout_r4_dim0major_; Layout layout_r4_dim0minor_; - std::unique_ptr literal_r4_2x2x3x3_dim0major_; - std::unique_ptr literal_r4_2x2x3x3_dim0minor_; + Literal literal_r4_2x2x3x3_dim0major_; + Literal literal_r4_2x2x3x3_dim0minor_; }; TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - EXPECT_EQ("true", true_lit->ToString()); + EXPECT_EQ("true", true_lit.ToString()); auto false_lit = LiteralUtil::CreateR0(false); - EXPECT_EQ("false", false_lit->ToString()); + EXPECT_EQ("false", false_lit.ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - EXPECT_EQ("42", u32_lit->ToString()); + EXPECT_EQ("42", u32_lit.ToString()); auto s32_lit = LiteralUtil::CreateR0(-999); - EXPECT_EQ("-999", s32_lit->ToString()); + EXPECT_EQ("-999", s32_lit.ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - EXPECT_EQ("3.14", f32_lit->ToString()); + EXPECT_EQ("3.14", f32_lit.ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", f16_lit->ToString()); + EXPECT_EQ("0.5", f16_lit.ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - EXPECT_EQ("(3.14, 2.78)", c64_lit->ToString()); + EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", bf16_lit->ToString()); + EXPECT_EQ("0.5", bf16_lit.ToString()); // 3.14 will be rounded to 3.14062 in bfloat16 format. auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.14062", bf16_lit_truncated->ToString()); + ASSERT_EQ("3.14062", bf16_lit_truncated.ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - EXPECT_EQ("9", bf16_lit_truncated2->ToString()); + EXPECT_EQ("9", bf16_lit_truncated2.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - EXPECT_EQ("{101}", pred_vec->ToString()); + EXPECT_EQ("{101}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -143,7 +143,7 @@ TEST_F(LiteralUtilTest, R2ToString) { { 3, 4 }, { 5, 6 } })"; - EXPECT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal.ToString()); } TEST_F(LiteralUtilTest, R3ToString) { @@ -157,13 +157,13 @@ TEST_F(LiteralUtilTest, R3ToString) { { { 5 }, { 6 } } })"; - EXPECT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal.ToString()); } TEST_F(LiteralUtilTest, TupleToString) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); const string expected = R"((f32[], f32[2,2]) ( 1, f32[2,2] { @@ -171,7 +171,7 @@ f32[2,2] { { 3, 4 } } ))"; - EXPECT_EQ(expected, tuple->ToString()); + EXPECT_EQ(expected, tuple.ToString()); } TEST_F(LiteralUtilTest, CreateR3FromArray3d) { @@ -187,8 +187,8 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { // clang-format on auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); - EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); - string result = literal->ToString(); + EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2)); + string result = literal.ToString(); const string expected = R"(f32[2,3,2] { { { 1, 2 }, { 3, 4 }, @@ -220,10 +220,10 @@ TEST_F(LiteralUtilTest, CreateSparse) { }; std::vector expected_values = {8, 9, 7, 10}; - EXPECT_EQ(literal->sparse_indices()->data(), + EXPECT_EQ(literal.sparse_indices()->data(), absl::Span(expected_indices.data(), expected_indices.num_elements())); - EXPECT_EQ(literal->data(), absl::Span(expected_values)); + EXPECT_EQ(literal.data(), absl::Span(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { @@ -234,8 +234,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { {2001, 2002}, }, /*projection_p=*/1, /*projection_z=*/2); // clang-format on - EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); - string result = literal->ToString(); + EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2)); + string result = literal.ToString(); const string expected = R"(f32[1,2,3,2] { { /*i0=0*/ { /*i1=0*/ @@ -254,9 +254,9 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { } TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { - EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(), + EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(), ElementsAre(2, 2, 3, 3)); - string result = literal_r4_2x2x3x3_dim0major_->ToString(); + string result = literal_r4_2x2x3x3_dim0major_.ToString(); const string expected = R"(f32[2,2,3,3] { { /*i0=0*/ { /*i1=0*/ @@ -294,7 +294,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { }); // clang-format on std::vector> seen; - literal->EachCellAsString( + literal.EachCellAsString( [&seen](absl::Span indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -310,14 +310,14 @@ TEST_F(LiteralUtilTest, ScalarEquality) { auto f32_42 = LiteralUtil::CreateR0(42.0); auto f32_42_clone = LiteralUtil::CreateR0(42.0); - EXPECT_EQ(*f32_42, *f32_42); - EXPECT_EQ(*f32_42, *f32_42_clone); + EXPECT_EQ(f32_42, f32_42); + EXPECT_EQ(f32_42, f32_42_clone); auto f32_123 = LiteralUtil::CreateR0(123.0); - EXPECT_NE(*f32_42, *f32_123); + EXPECT_NE(f32_42, f32_123); auto f64_42 = LiteralUtil::CreateR0(42.0); - EXPECT_NE(*f32_42, *f64_42); + EXPECT_NE(f32_42, f64_42); } TEST_F(LiteralUtilTest, NonScalarEquality) { @@ -330,12 +330,12 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { auto scalar = LiteralUtil::CreateR0(1.0); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(*matrix, *matrix); - EXPECT_EQ(*matrix, *matrix_clone); - EXPECT_NE(*matrix, *matrix_different); - EXPECT_NE(*matrix, *vector_literal); - EXPECT_NE(*matrix, *scalar); - EXPECT_NE(*matrix, nil); + EXPECT_EQ(matrix, matrix); + EXPECT_EQ(matrix, matrix_clone); + EXPECT_NE(matrix, matrix_different); + EXPECT_NE(matrix, vector_literal); + EXPECT_NE(matrix, scalar); + EXPECT_NE(matrix, nil); EXPECT_EQ(nil, nil); } @@ -344,57 +344,54 @@ TEST_F(LiteralUtilTest, TokenEquality) { auto token1 = LiteralUtil::CreateToken(); auto scalar = LiteralUtil::CreateR0(1.0); - EXPECT_EQ(*token0, *token1); - EXPECT_NE(*token0, *scalar); + EXPECT_EQ(token0, token1); + EXPECT_NE(token0, scalar); - EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}), - *LiteralUtil::MakeTuple({token0.get()})); - EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), - *LiteralUtil::MakeTuple({token1.get(), scalar.get()})); - EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), - *LiteralUtil::MakeTuple({scalar.get(), token1.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&token0}), + LiteralUtil::MakeTuple({&token0})); + EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}), + LiteralUtil::MakeTuple({&token1, &scalar})); + EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}), + LiteralUtil::MakeTuple({&scalar, &token1})); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. - auto colmajor = absl::make_unique( - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); - colmajor->Set({0, 0}, 1.0); - colmajor->Set({0, 1}, 2.0); - colmajor->Set({1, 0}, 3.0); - colmajor->Set({1, 1}, 4.0); + Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); + colmajor.Set({0, 0}, 1.0); + colmajor.Set({0, 1}, 2.0); + colmajor.Set({1, 0}, 3.0); + colmajor.Set({1, 1}, 4.0); - auto rowmajor = absl::make_unique( - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); - rowmajor->Set({0, 0}, 1.0); - rowmajor->Set({0, 1}, 2.0); - rowmajor->Set({1, 0}, 3.0); - rowmajor->Set({1, 1}, 4.0); + Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); + rowmajor.Set({0, 0}, 1.0); + rowmajor.Set({0, 1}, 2.0); + rowmajor.Set({1, 0}, 3.0); + rowmajor.Set({1, 1}, 4.0); - EXPECT_EQ(*rowmajor, *colmajor); + EXPECT_EQ(rowmajor, colmajor); } TEST_F(LiteralUtilTest, TupleEquality) { // Test equality with tuples. auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. auto scalar_clone = LiteralUtil::CreateR0(1.0); - auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); - EXPECT_EQ(*tuple1, *tuple2); + auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix}); + EXPECT_EQ(tuple1, tuple2); // Tuple with elements reversed. - auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); - EXPECT_NE(*tuple1, *reversed_tuple); + auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar}); + EXPECT_NE(tuple1, reversed_tuple); // Tuple with different value. auto scalar_42 = LiteralUtil::CreateR0(42.0); - auto different_tuple = - LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); - EXPECT_NE(*tuple1, *different_tuple); + auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix}); + EXPECT_NE(tuple1, different_tuple); } TEST_F(LiteralUtilTest, C64Equality) { @@ -405,162 +402,161 @@ TEST_F(LiteralUtilTest, C64Equality) { // tuple, the other is a clone of the element in the original tuple. auto vector_clone = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); - EXPECT_EQ(*vector, *vector_clone); + EXPECT_EQ(vector, vector_clone); auto vector_reversed = LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); - EXPECT_NE(*vector, *vector_reversed); + EXPECT_NE(vector, vector_reversed); } TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = LiteralUtil::CreateR0(0.0); auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); + auto tuple = LiteralUtil::MakeTuple({&element1, &element1}); // Tuples should always return false for IsAll. - EXPECT_FALSE(tuple->IsAll(0)); - EXPECT_FALSE(tuple->IsAll(1)); + EXPECT_FALSE(tuple.IsAll(0)); + EXPECT_FALSE(tuple.IsAll(1)); } // Verifies that CreateFromShape works for tuples. TEST_F(LiteralUtilTest, CreateFromShapeTuple) { auto scalar = LiteralUtil::CreateR0(0.0); auto matrix = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); - auto x = Literal::CreateFromShape(tuple->shape()); - EXPECT_EQ(*tuple, *x); + auto x = Literal::CreateFromShape(tuple.shape()); + EXPECT_EQ(tuple, x); } TEST_F(LiteralUtilTest, IsAll) { - EXPECT_TRUE(LiteralUtil::CreateR0(false)->IsAll(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(true)->IsAll(1)); - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(1)); - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(2)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(2)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(-1)); + EXPECT_TRUE(LiteralUtil::CreateR0(false).IsAll(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(true).IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(-1)); // We shouldn't reinterpret int8_min as an unsigned type and then decide that // it is equal to 255. auto int8_min = std::numeric_limits::min(); - EXPECT_FALSE(LiteralUtil::CreateR0(255)->IsAll(int8_min)); + EXPECT_FALSE(LiteralUtil::CreateR0(255).IsAll(int8_min)); - EXPECT_TRUE(LiteralUtil::CreateR0(42.0)->IsAll(42)); - EXPECT_FALSE(LiteralUtil::CreateR0(42.0001)->IsAll(42)); + EXPECT_TRUE(LiteralUtil::CreateR0(42.0).IsAll(42)); + EXPECT_FALSE(LiteralUtil::CreateR0(42.0001).IsAll(42)); - EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100})->IsAll(100)); - EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001})->IsAll(100)); + EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100}).IsAll(100)); + EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001}).IsAll(100)); - EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}}).IsAll(8)); half h8(8.0f); half h9(9.0f); - EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}}).IsAll(8)); bfloat16 b8(8.0f); bfloat16 b9(9.0f); - EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}}).IsAll(8)); // 9.001 will be truncated to 9.0 bfloat16 b91(9.001f); bfloat16 b90(9.00f); - EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}})->IsAll(9.0)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); complex64 c8_9 = {8, 9}; - EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); auto uint64_max = std::numeric_limits::max(); EXPECT_FALSE(LiteralUtil::CreateR2( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) - ->IsAll(-1)); + .IsAll(-1)); } TEST_F(LiteralUtilTest, IsAllFloat) { // IsAllFloat always returns false when the literal is not floating-point. - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - - EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + + EXPECT_TRUE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5).IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0)); EXPECT_TRUE(LiteralUtil::CreateR2({{.5, .5, .5}, {.5, .5, .5}}) - ->IsAllFloat(.5)); + .IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_TRUE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5).IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0)); } TEST_F(LiteralUtilTest, IsAllComplex) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c7_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); } TEST_F(LiteralUtilTest, IsAllFirst) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(LiteralUtil::CreateR1({false, true})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({false, false})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({false, true}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({false, false}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; - EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); - EXPECT_FALSE( - LiteralUtil::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}).IsAllFirst()); } TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = LiteralUtil::CreateR0(0.0f); auto scalar_one = LiteralUtil::CreateR0(1.0f); - EXPECT_TRUE(scalar_zero->IsZero({})); - EXPECT_FALSE(scalar_one->IsZero({})); + EXPECT_TRUE(scalar_zero.IsZero({})); + EXPECT_FALSE(scalar_one.IsZero({})); auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); - EXPECT_FALSE(array->IsZero({0, 1})); - EXPECT_TRUE(array->IsZero({0, 2})); - EXPECT_TRUE(array->IsZero({1, 1})); - EXPECT_FALSE(array->IsZero({1, 2})); + EXPECT_FALSE(array.IsZero({0, 1})); + EXPECT_TRUE(array.IsZero({0, 2})); + EXPECT_TRUE(array.IsZero({1, 1})); + EXPECT_FALSE(array.IsZero({1, 2})); auto complex_zero = LiteralUtil::CreateR0(0.0f); auto complex_nonzero = LiteralUtil::CreateR0(0.5f); - EXPECT_TRUE(complex_zero->IsZero({})); - EXPECT_FALSE(complex_nonzero->IsZero({})); + EXPECT_TRUE(complex_zero.IsZero({})); + EXPECT_FALSE(complex_nonzero.IsZero({})); } template @@ -576,19 +572,19 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); - auto data01 = data->Relayout(layout01); - EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); - EXPECT_EQ(*data, *data01); + auto data01 = data.Relayout(layout01); + EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01)); + EXPECT_EQ(data, data01); - auto data10 = data->Relayout(layout10); - EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); - EXPECT_EQ(*data, *data10); + auto data10 = data.Relayout(layout10); + EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10)); + EXPECT_EQ(data, data10); } TEST_F(LiteralUtilTest, ReshapeR0) { auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); - EXPECT_EQ(*original, *reshape); + auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); + EXPECT_EQ(original, reshape); } TEST_F(LiteralUtilTest, ReshapeR4) { @@ -606,9 +602,9 @@ TEST_F(LiteralUtilTest, ReshapeR4) { {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_EQ(*expected, *reshape); + EXPECT_EQ(expected, reshape); } TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { @@ -626,15 +622,15 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_EQ(*expected, *reshape); + EXPECT_EQ(expected, reshape); } TEST_F(LiteralUtilTest, TransposeR0) { auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = original->Transpose(/*permutation=*/{}); - EXPECT_EQ(*original, *reshape); + auto reshape = original.Transpose(/*permutation=*/{}); + EXPECT_EQ(original, reshape); } TEST_F(LiteralUtilTest, TransposeR4) { @@ -646,10 +642,10 @@ TEST_F(LiteralUtilTest, TransposeR4) { {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}); // clang-format on - auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); + auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1}); - reshape->EachCell([&](absl::Span indices, float value) { - EXPECT_EQ(value, original->Get( + reshape.EachCell([&](absl::Span indices, float value) { + EXPECT_EQ(value, original.Get( {indices[2], indices[3], indices[0], indices[1]})); }); } @@ -658,35 +654,35 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { // Tests that using Relayout on an array is equivalent to creating it in the // target layout in the first place. auto dim0minor_relaid_to_dim0major = - literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_); - EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major); + literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_); + EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major); auto dim0major_relaid_to_dim0minor = - literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_); - EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor); + literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_); + EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor); } TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); - EXPECT_EQ(mat_dim0minor->element_count(), 6); - EXPECT_THAT(mat_dim0minor->data(), ElementsAre(1, 4, 2, 5, 3, 6)); + EXPECT_EQ(mat_dim0minor.element_count(), 6); + EXPECT_THAT(mat_dim0minor.data(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. - auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_); - EXPECT_THAT(relaid_mat_to_dim0major->data(), + auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_); + EXPECT_THAT(relaid_mat_to_dim0major.data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). auto mat_dim0major = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); - EXPECT_EQ(mat_dim0major->element_count(), 6); - EXPECT_THAT(mat_dim0major->data(), ElementsAre(1, 2, 3, 4, 5, 6)); + EXPECT_EQ(mat_dim0major.element_count(), 6); + EXPECT_THAT(mat_dim0major.data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. - auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_); - EXPECT_THAT(relaid_mat_to_dim0minor->data(), + auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_); + EXPECT_THAT(relaid_mat_to_dim0minor.data(), ElementsAre(1, 4, 2, 5, 3, 6)); } @@ -707,77 +703,77 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0minor_); - EXPECT_EQ(lit_dim0minor->element_count(), 12); + EXPECT_EQ(lit_dim0minor.element_count(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; - EXPECT_THAT(lit_dim0minor->data(), + EXPECT_THAT(lit_dim0minor.data(), testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. - auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_); + auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - EXPECT_THAT(relaid_lit_to_dim0major->data(), + EXPECT_THAT(relaid_lit_to_dim0major.data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0major_); - EXPECT_EQ(lit_dim0major->element_count(), 12); - EXPECT_THAT(lit_dim0major->data(), + EXPECT_EQ(lit_dim0major.element_count(), 12); + EXPECT_THAT(lit_dim0major.data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. - auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_); - EXPECT_THAT(relaid_lit_to_dim0minor->data(), + auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_); + EXPECT_THAT(relaid_lit_to_dim0minor.data(), testing::ElementsAreArray(expected_dim0minor)); } TEST_F(LiteralUtilTest, SliceR0S32) { auto input = LiteralUtil::CreateR0(1); - auto result = input->Slice({}, {}); - EXPECT_EQ(*input, *result); + auto result = input.Slice({}, {}); + EXPECT_EQ(input, result); } TEST_F(LiteralUtilTest, SliceR1F32) { auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); - auto result = input->Slice({3}, {4}); + auto result = input.Slice({3}, {4}); auto expected = LiteralUtil::CreateR1({4.0}); - EXPECT_EQ(*expected, *result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, SliceR2U32) { auto input_3x4 = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto result = input_3x4->Slice({0, 2}, {2, 4}); + auto result = input_3x4.Slice({0, 2}, {2, 4}); auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); - EXPECT_EQ(*expected, *result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, SliceR3U32Full) { auto input_2x3x2 = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); - EXPECT_EQ(*input_2x3x2, *result); + auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2}); + EXPECT_EQ(input_2x3x2, result); } TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output(ShapeUtil::MakeShape(S64, {1})); output.PopulateR1({77}); auto expected = LiteralUtil::CreateR1({77}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output(ShapeUtil::MakeShape(U64, {2})); output.PopulateR1({{77, 88}}); auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR1C64) { Literal output(ShapeUtil::MakeShape(C64, {1})); output.PopulateR1({{77, 88}}); auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR2C64) { @@ -785,7 +781,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); auto expected = LiteralUtil::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { @@ -793,7 +789,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { bfloat16 h(0.25f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { @@ -801,7 +797,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { bfloat16 h(0.5f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { @@ -809,28 +805,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { bfloat16 h(2.0f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output(ShapeUtil::MakeShape(F32, {})); output.PopulateWithValue(2.5f); auto expected = LiteralUtil::CreateR0(2.5f); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output(ShapeUtil::MakeShape(S64, {3})); output.PopulateWithValue(-7); auto expected = LiteralUtil::CreateR1({-7, -7, -7}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output(ShapeUtil::MakeShape(U64, {2, 2})); output.PopulateWithValue(42); auto expected = LiteralUtil::CreateR2({{42, 42}, {42, 42}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { @@ -838,7 +834,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { output.PopulateWithValue({4, 2}); auto expected = LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { @@ -846,7 +842,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { half h(0.25f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { @@ -854,7 +850,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { half h(0.5f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { @@ -862,18 +858,18 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { half h(2.0f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto output = input->Replicate(3); + auto output = input.Replicate(3); auto expected = LiteralUtil::CreateR3( {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); - EXPECT_EQ(*output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, CopySliceFrom) { @@ -889,17 +885,17 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; auto init_proc = [&](absl::Span indexes) { - source->Set(indexes, ++seqnr); + source.Set(indexes, ++seqnr); return true; }; - ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step, + ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step, init_proc); auto blank = Literal::CreateFromShape(shape); const int64 src_base[] = {3, 1, 5, 7}; const int64 dest_base[] = {6, 4, 12, 2}; const int64 copy_size[] = {7, 8, 11, 9}; - TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size)); + TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size)); std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); @@ -911,12 +907,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, blank_indexes.begin(), std::plus()); - auto bval = blank->Get(blank_indexes); - matched = (bval != 0 && bval == source->Get(source_indexes)); + auto bval = blank.Get(blank_indexes); + matched = (bval != 0 && bval == source.Get(source_indexes)); return matched; }; - ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, + ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step, check_proc); EXPECT_TRUE(matched); } @@ -925,14 +921,14 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { TEST_F(LiteralUtilTest, CopyFromScalars) { auto zero = LiteralUtil::CreateR0(0); auto nine = LiteralUtil::CreateR0(9); - TF_EXPECT_OK(zero->CopyFrom(*nine)); - EXPECT_EQ(*zero, *nine); + TF_EXPECT_OK(zero.CopyFrom(nine)); + EXPECT_EQ(zero, nine); auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); - TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {})); - EXPECT_EQ(zero->Get({}), 17); - TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {})); - EXPECT_EQ(vect->Get({4}), 17); + TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {})); + EXPECT_EQ(zero.Get({}), 17); + TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {})); + EXPECT_EQ(vect.Get({4}), 17); } TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { @@ -945,17 +941,17 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = LiteralUtil::CreateR1({9}); - TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0})); - EXPECT_EQ(*nine, *const_nine); + TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0})); + EXPECT_EQ(nine, const_nine); } { // Copy 0 element to destination with zero elements. - const auto empty = Literal::CreateFromShape(empty_r1_shape); + auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = LiteralUtil::CreateR1({9}); - TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0})); - EXPECT_EQ(*empty, *const_empty); + TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0})); + EXPECT_EQ(empty, const_empty); } } @@ -969,74 +965,75 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) { TEST_F(LiteralUtilTest, CopyFromArrays) { auto scalar_42 = LiteralUtil::CreateR0(42.0); auto scalar_123 = LiteralUtil::CreateR0(123.0); - EXPECT_NE(*scalar_42, *scalar_123); - TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{}, - /*src_shape_index=*/{})); - EXPECT_EQ(*scalar_42, *scalar_123); - EXPECT_EQ(scalar_42->Get({}), 123.0f); + EXPECT_NE(scalar_42, scalar_123); + TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(scalar_42, scalar_123); + EXPECT_EQ(scalar_42.Get({}), 123.0f); auto matrix_1234 = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto matrix_5678 = LiteralUtil::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); - EXPECT_NE(*matrix_1234, *matrix_5678); - EXPECT_EQ(matrix_1234->Get({0, 0}), 1.0f); - TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{}, - /*src_shape_index=*/{})); - EXPECT_EQ(*matrix_1234, *matrix_5678); - EXPECT_EQ(matrix_1234->Get({0, 0}), 5.0f); + EXPECT_NE(matrix_1234, matrix_5678); + EXPECT_EQ(matrix_1234.Get({0, 0}), 1.0f); + TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(matrix_1234, matrix_5678); + EXPECT_EQ(matrix_1234.Get({0, 0}), 5.0f); } TEST_F(LiteralUtilTest, CopyFromTuples) { auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {matrix.get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) - .get()}); + Literal inner_elements[] = {LiteralUtil::CreateR0(42), + LiteralUtil::CreateR1({23.0, 44.0})}; + Literal inner_tuple = LiteralUtil::MakeTuple( + {&inner_elements[0], &inner_elements[1], &nil_literal}); + Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple}); // Create a tuple the same shape as the inner tuple of nested_tuple but with // different values.. - auto tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(-5).get(), - LiteralUtil::CreateR1({2.0, 4.0}).get(), &nil_literal}); + Literal int32_minus5 = LiteralUtil::CreateR0(-5); + Literal double_2_4 = LiteralUtil::CreateR1({2.0, 4.0}); + Literal tuple = + LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal}); - EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); - EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); - EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); - EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); + EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0})); + EXPECT_EQ(nested_tuple.Get({}, {1, 0}), 42); + EXPECT_EQ(nested_tuple.Get({0}, {1, 1}), 23.0); + EXPECT_EQ(nested_tuple.Get({1}, {1, 1}), 44.0); // Overwrite the inner tuple element of nested_tuple with the contents of // 'tuple'. - TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, - /*src_shape_index=*/{})); + TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{})); // The matrix element should be unchanged. - EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); + EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0})); // The tuple element should have been copied from 'tuple'. - EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); - EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 2.0); - EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 4.0); + EXPECT_EQ(nested_tuple.Get({}, {1, 0}), -5); + EXPECT_EQ(nested_tuple.Get({0}, {1, 1}), 2.0); + EXPECT_EQ(nested_tuple.Get({1}, {1, 1}), 4.0); } TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { - auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0(-2).get(), - LiteralUtil::CreateR0(4).get()}); + Literal elements[] = {LiteralUtil::CreateR0(-2), + LiteralUtil::CreateR0(4)}; + Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); - EXPECT_EQ(tuple->Get({}, {0}), -2); - EXPECT_EQ(tuple->Get({}, {1}), 4); + EXPECT_EQ(tuple.Get({}, {0}), -2); + EXPECT_EQ(tuple.Get({}, {1}), 4); // Copy from one element to the other. - TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, - /*src_shape_index=*/{0})); + TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{0})); - EXPECT_EQ(tuple->Get({}, {0}), -2); - EXPECT_EQ(tuple->Get({}, {1}), -2); + EXPECT_EQ(tuple.Get({}, {0}), -2); + EXPECT_EQ(tuple.Get({}, {1}), -2); } TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto vector = LiteralUtil::CreateR1({5.0, 7.0}); - Status status = matrix->CopyFrom(*vector); + Status status = matrix.CopyFrom(vector); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), HasSubstr("Destination subshape incompatible")); @@ -1046,9 +1043,8 @@ TEST_F(LiteralUtilTest, F16) { // Verify that the internal data views are consistent and that they // are in little endian format // TODO - modify if we make the data format machine endianess dependent - auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); - Literal* l1 = m1.get(); - const char* d1 = reinterpret_cast(l1->data().data()); + Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + const char* d1 = reinterpret_cast(m1.data().data()); EXPECT_EQ(d1[0], 0); EXPECT_EQ(d1[1], 0); EXPECT_EQ(d1[2], 0); @@ -1061,8 +1057,7 @@ TEST_F(LiteralUtilTest, F16) { half h1(1.0f); half h2(2.0f); auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); - Literal* l2 = m2.get(); - const char* d2 = reinterpret_cast(l2->data().data()); + const char* d2 = reinterpret_cast(m2.data().data()); EXPECT_EQ(d2[0], 0); EXPECT_EQ(d2[1], 0x3C); EXPECT_EQ(d2[2], 0); @@ -1091,25 +1086,25 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = absl::make_unique(shape); + Literal literal(shape); auto generator = [&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), indexes) + 17; }; - TF_EXPECT_OK(literal->Populate(generator)); + TF_EXPECT_OK(literal.Populate(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; auto check_function = [&](absl::Span indexes) { - auto value = literal->Get(indexes); + auto value = literal.Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; - ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step, check_function); EXPECT_TRUE(matched); } @@ -1133,25 +1128,25 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = absl::make_unique(shape); + Literal literal(shape); auto generator = [&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), indexes) + 17; }; - TF_EXPECT_OK(literal->PopulateParallel(generator)); + TF_EXPECT_OK(literal.PopulateParallel(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; auto check_function = [&](absl::Span indexes) { - auto value = literal->Get(indexes); + auto value = literal.Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; - ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step, check_function); EXPECT_TRUE(matched); } @@ -1170,10 +1165,9 @@ TEST_F(LiteralUtilTest, ConvertR4) { {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // clang-format on - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, - original->Convert(U32)); + TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32)); - EXPECT_EQ(*expected, *converted); + EXPECT_EQ(expected, converted); } TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { @@ -1245,69 +1239,65 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); // clang-format on - std::unique_ptr conv; + Literal conv; - conv = s8->Convert(U32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *u32); + conv = s8.Convert(U32).ConsumeValueOrDie(); + EXPECT_EQ(conv, u32); - conv = s8->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = s8.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = s8->Convert(U64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *u64); + conv = s8.Convert(U64).ConsumeValueOrDie(); + EXPECT_EQ(conv, u64); - conv = s8->Convert(S64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s64); + conv = s8.Convert(S64).ConsumeValueOrDie(); + EXPECT_EQ(conv, s64); - conv = s8->Convert(PRED).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *pred); + conv = s8.Convert(PRED).ConsumeValueOrDie(); + EXPECT_EQ(conv, pred); - conv = bf16->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = bf16.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = bf16->Convert(F32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f32); + conv = bf16.Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(conv, f32); - conv = pred->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *int32_pred); + conv = pred.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, int32_pred); - conv = f32->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = f32.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = f64->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = f64.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = s32->Convert(F32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f32); + conv = s32.Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(conv, f32); - conv = f32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = f32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = f64->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = f64.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = s32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = s32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = u32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = u32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = s32->Convert(C64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *c64); + conv = s32.Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(conv, c64); - conv = f16->Convert(C64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *c64); + conv = f16.Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(conv, c64); - EXPECT_EQ(s32->Convert(TUPLE).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32->Convert(S16).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32->Convert(U16).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(c64->Convert(F32).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(c64->Convert(S32).status().code(), + EXPECT_EQ(s32.Convert(TUPLE).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED); } TEST_F(LiteralUtilTest, BitcastConvert) { @@ -1317,13 +1307,12 @@ TEST_F(LiteralUtilTest, BitcastConvert) { tensorflow::bit_cast(100.f), 0xbeef}); auto expected = LiteralUtil::CreateR1( {2.5f, -42.25f, 100.0f, tensorflow::bit_cast(0xbeef)}); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, - original->BitcastConvert(F32)); + TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32)); } TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { auto literal = LiteralUtil::CreateR0(1234); - Status status = literal->BitcastConvert(F64).status(); + Status status = literal.BitcastConvert(F64).status(); EXPECT_NE(Status::OK(), status); EXPECT_TRUE( absl::StrContains(status.error_message(), "bit widths are different")); @@ -1341,11 +1330,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) { p.add_preds((i % 2) == (len % 2)); } - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, - Literal::CreateFromProto(p)); - ASSERT_EQ(len, literal->data().size()); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + ASSERT_EQ(len, literal.data().size()); int i = 0; - for (bool value : literal->data()) { + for (bool value : literal.data()) { EXPECT_EQ((i % 2) == (len % 2), value); ++i; } @@ -1358,11 +1346,10 @@ TEST_F(LiteralUtilTest, ToProto_f16) { half h2(2.0f); auto m = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); - Literal* l = m.get(); - EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); - EXPECT_EQ(4, l->data().size()); + EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape())); + EXPECT_EQ(4, m.data().size()); - LiteralProto p = l->ToProto(); + LiteralProto p = m.ToProto(); EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); EXPECT_EQ(8, p.f16s().size()); const char* d = p.f16s().data(); @@ -1389,9 +1376,8 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { LayoutUtil::SetToDefaultLayout(p.mutable_shape()); p.clear_f16s(); p.set_f16s(half_vals, 8); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, - Literal::CreateFromProto(p)); - auto r = literal->data(); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + auto r = literal.data(); ASSERT_EQ(4, r.size()); EXPECT_EQ(h1, r[0]); EXPECT_EQ(h2, r[1]); @@ -1402,43 +1388,41 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); - EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix); - EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple); - EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralSlice(scalar, {}), scalar); + EXPECT_EQ(LiteralSlice(matrix, {}), matrix); + EXPECT_EQ(LiteralSlice(tuple, {}), tuple); + EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple); EXPECT_EQ(LiteralSlice(nil, {}), nil); - EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar); - EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix); + EXPECT_EQ(LiteralSlice(tuple, {0}), scalar); + EXPECT_EQ(LiteralSlice(tuple, {1}), matrix); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix); - EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar); + EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple); + EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar); + EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix); + EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar); } TEST_F(LiteralUtilTest, MutatingLiteralSlice) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. - const auto nested_tuple_view = LiteralSlice(*nested_tuple); - EXPECT_EQ( - nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), - 1.0f); + const auto nested_tuple_view = LiteralSlice(nested_tuple); + EXPECT_EQ(nested_tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 1.0f); EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1.0f); - nested_tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); - EXPECT_EQ( - nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), - 555.0f); + nested_tuple.Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); + EXPECT_EQ(nested_tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 555.0f); EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 555.0f); @@ -1447,14 +1431,14 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) { TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); - const auto nested_tuple_view = LiteralSlice(*nested_tuple); + const auto nested_tuple_view = LiteralSlice(nested_tuple); const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); EXPECT_EQ(matrix_view, - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { @@ -1497,9 +1481,8 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { } TEST_F(LiteralUtilTest, LiteralMove) { - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - Literal literal(std::move(*matrix)); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal(std::move(matrix)); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); @@ -1511,17 +1494,21 @@ TEST_F(LiteralUtilTest, LiteralMove) { TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) - .get(), - &nil_literal}); - - EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape())); - std::vector elements = nested_tuple->DecomposeTuple(); - EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape())); + Literal inner_elements[] = { + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR1({23.0, 44.0}), + }; + Literal tuple_elements[] = { + LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::MakeTuple( + {&inner_elements[0], &inner_elements[1], &nil_literal}), + }; + Literal nested_tuple = LiteralUtil::MakeTuple( + {&tuple_elements[0], &tuple_elements[1], &nil_literal}); + + EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape())); + std::vector elements = nested_tuple.DecomposeTuple(); + EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape())); ASSERT_EQ(elements.size(), 3); @@ -1552,13 +1539,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) { TEST_F(LiteralUtilTest, MoveIntoTuple) { std::vector elements; - elements.push_back(std::move(*LiteralUtil::CreateR0(1.0))); - elements.push_back(std::move(*LiteralUtil::CreateR1({4, 8}))); - elements.push_back(std::move(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get()}) - - )); + elements.push_back(LiteralUtil::CreateR0(1.0)); + elements.push_back(LiteralUtil::CreateR1({4, 8})); + std::vector inner_elements; + inner_elements.push_back(LiteralUtil::CreateR0(42)); + inner_elements.push_back(LiteralUtil::CreateR1({23.0, 44.0})); + elements.push_back( + LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]})); Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements)); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); @@ -1586,9 +1573,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { Literal literal; EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape())); - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - literal = std::move(*matrix); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + literal = std::move(matrix); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); @@ -1599,9 +1585,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { } TEST_F(LiteralUtilTest, LiteralSliceCopy) { - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - const auto matrix_view = LiteralSlice(*matrix); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + const auto matrix_view = LiteralSlice(matrix); LiteralSlice matrix_view_copy(matrix_view); EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); @@ -1611,45 +1596,43 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) { } TEST_F(LiteralUtilTest, GetSetTuple) { - auto tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42.0).get(), - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); - EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); - tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); - EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); - - EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), - 3.0); - tuple->Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); - EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), + Literal elements[] = { + LiteralUtil::CreateR0(42.0), + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + }; + auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); + EXPECT_EQ(tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); + tuple.Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); + EXPECT_EQ(tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); + + EXPECT_EQ(tuple.Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0); + tuple.Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); + EXPECT_EQ(tuple.Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), -4.0); } TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { // Literals constructed using CreateFromShape should be zero initialized. - std::unique_ptr scalar_f32 = - Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); - EXPECT_EQ(scalar_f32->Get({}), 0.0); - EXPECT_TRUE(scalar_f32->IsAll(0)); - - std::unique_ptr vector_s32 = - Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); - EXPECT_EQ(vector_s32->Get({0}), 0); - EXPECT_EQ(vector_s32->Get({1}), 0); - EXPECT_EQ(vector_s32->Get({2}), 0); - EXPECT_TRUE(vector_s32->IsAll(0)); - - std::unique_ptr tuple = - Literal::CreateFromShape(ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), - ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); - - EXPECT_EQ(tuple->Get({}, {0}), 0.0); - EXPECT_EQ(tuple->Get({0}, {1}), false); - EXPECT_EQ(tuple->Get({1}, {1}), false); - EXPECT_EQ(tuple->Get({0, 0}, {2}), 0); - EXPECT_EQ(tuple->Get({1, 0}, {2}), 0); - EXPECT_EQ(tuple->Get({}, {3}), complex64(0.0f, 0.0f)); + Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); + EXPECT_EQ(scalar_f32.Get({}), 0.0); + EXPECT_TRUE(scalar_f32.IsAll(0)); + + Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); + EXPECT_EQ(vector_s32.Get({0}), 0); + EXPECT_EQ(vector_s32.Get({1}), 0); + EXPECT_EQ(vector_s32.Get({2}), 0); + EXPECT_TRUE(vector_s32.IsAll(0)); + + Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), + ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); + + EXPECT_EQ(tuple.Get({}, {0}), 0.0); + EXPECT_EQ(tuple.Get({0}, {1}), false); + EXPECT_EQ(tuple.Get({1}, {1}), false); + EXPECT_EQ(tuple.Get({0, 0}, {2}), 0); + EXPECT_EQ(tuple.Get({1, 0}, {2}), 0); + EXPECT_EQ(tuple.Get({}, {3}), complex64(0.0f, 0.0f)); } TEST_F(LiteralUtilTest, ProtoRoundTrip) { @@ -1657,6 +1640,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto one_f32 = LiteralUtil::CreateR0(1.0); auto two_f32 = LiteralUtil::CreateR0(2.0); auto vector_int8 = LiteralUtil::CreateR1({-128, 0, 2, 4, 7, 56, 127}); + auto vector_uint8 = LiteralUtil::CreateR1({128, 0, 2, 56, 127, 255}); auto vector_c64 = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); auto vector_bfloat16 = LiteralUtil::CreateR1( {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); @@ -1665,25 +1649,27 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto tuple = LiteralUtil::MakeTuple( - {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()}); + {&one_f32, &vector_half, &matrix_pred, &matrix_pred}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal}); + auto nested_tuple = + LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal}); auto to_from_proto = [](const Literal& literal) -> Literal { - return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie()); + return Literal::CreateFromProto(literal.ToProto()).ValueOrDie(); }; - EXPECT_EQ(*one_f32, to_from_proto(*one_f32)); - EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64)); - EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16)); - EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred)); - EXPECT_EQ(*tuple, to_from_proto(*tuple)); - EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple)); + EXPECT_EQ(one_f32, to_from_proto(one_f32)); + EXPECT_EQ(vector_int8, to_from_proto(vector_int8)); + EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8)); + EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); + EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); + EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); + EXPECT_EQ(tuple, to_from_proto(tuple)); + EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple)); EXPECT_EQ(nil_literal, to_from_proto(nil_literal)); - EXPECT_NE(*one_f32, *two_f32); - EXPECT_NE(*one_f32, to_from_proto(*two_f32)); + EXPECT_NE(one_f32, two_f32); + EXPECT_NE(one_f32, to_from_proto(two_f32)); } TEST_F(LiteralUtilTest, InvalidProtoNoValues) { @@ -1802,11 +1788,11 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { TEST_F(LiteralUtilTest, SortSparseElements) { auto literal = LiteralUtil::CreateSparse({10, 10, 10}, SparseIndexArray(10, 3), {}); - literal->AppendSparseElement({2, 3, 4}, 2.0); - literal->AppendSparseElement({3, 4, 5}, 3.0); - literal->AppendSparseElement({1, 2, 3}, 1.0); - literal->SortSparseElements(); - EXPECT_EQ(literal->ToString(false), + literal.AppendSparseElement({2, 3, 4}, 2.0); + literal.AppendSparseElement({3, 4, 5}, 3.0); + literal.AppendSparseElement({1, 2, 3}, 1.0); + literal.SortSparseElements(); + EXPECT_EQ(literal.ToString(false), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } @@ -1816,57 +1802,54 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {true, false, true}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), "false"); EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat(int64{2})); EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat(double{2.0})); EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat(static_cast(half{2.0}))); EXPECT_EQ(LiteralUtil::CreateSparse( dimensions, indices, std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { - std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); + Literal literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), - /*dimensions=*/{0})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{1, 1}, {2, 2}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{0})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{1, 1}, {2, 2}})); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { - std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); + Literal literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), - /*dimensions=*/{1})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{1, 2}, {1, 2}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{1})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{1, 2}, {1, 2}})); } TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { - std::unique_ptr literal = LiteralUtil::CreateR0(9); + Literal literal = LiteralUtil::CreateR0(9); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), - /*dimensions=*/{})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{9, 9}, {9, 9}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), + /*dimensions=*/{})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{9, 9}, {9, 9}})); } } // namespace diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 613449cf10c785de55e8474c0ee35f78e8ed92b4..0cb1ae35f4ad31f091063d78ed32c1463be8ee0a 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -45,7 +45,7 @@ using absl::StrCat; // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. template -std::unique_ptr ConvertType(LiteralSlice literal) { +Literal ConvertType(LiteralSlice literal) { // First construct shape of the result. Shape result_shape(literal.shape()); ShapeUtil::ForEachMutableSubshape( @@ -56,7 +56,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { primitive_util::NativeToPrimitiveType()); } }); - auto result = absl::make_unique(result_shape); + Literal result(result_shape); // Then copy over the data from 'literal' converting FromNativeT values to // ToNativeT values as necessary. @@ -67,14 +67,14 @@ std::unique_ptr ConvertType(LiteralSlice literal) { if (subshape.element_type() == primitive_util::NativeToPrimitiveType()) { auto src = literal.data(shape_index); - auto dest = result->data(shape_index); + auto dest = result.data(shape_index); for (int64 i = 0; i < src.size(); ++i) { dest[i] = static_cast(src[i]); } } else { - TF_CHECK_OK(result->CopyFrom(literal, - /*dest_shape_index=*/shape_index, - /*src_shape_index=*/shape_index)); + TF_CHECK_OK(result.CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); } } }); @@ -83,53 +83,52 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } // namespace -/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( +/* static */ Literal LiteralUtil::CreateFromDimensions( PrimitiveType primitive_type, absl::Span dimensions) { return Literal::CreateFromShape( ShapeUtil::MakeShape(primitive_type, dimensions)); } -/* static */ std::unique_ptr LiteralUtil::ConvertBF16ToF32( +/* static */ Literal LiteralUtil::ConvertBF16ToF32( const LiteralSlice& bf16_literal) { return ConvertType(bf16_literal); } -/* static */ std::unique_ptr LiteralUtil::ConvertF32ToBF16( +/* static */ Literal LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); } -/* static */ std::unique_ptr LiteralUtil::CreateToken() { - return absl::make_unique(ShapeUtil::MakeTokenShape()); +/* static */ Literal LiteralUtil::CreateToken() { + return Literal(ShapeUtil::MakeTokenShape()); } /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case U32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case U64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S8: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case F16: - return std::move(*LiteralUtil::CreateR0(static_cast(0.0f))); + return LiteralUtil::CreateR0(static_cast(0.0f)); case BF16: - return std::move( - *LiteralUtil::CreateR0(static_cast(0.0f))); + return LiteralUtil::CreateR0(static_cast(0.0f)); case F32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case F64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case C64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case PRED: - return std::move(*LiteralUtil::CreateR0(false)); + return LiteralUtil::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -145,30 +144,29 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case U32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case U64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S8: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case F16: - return std::move(*LiteralUtil::CreateR0(static_cast(1.0f))); + return LiteralUtil::CreateR0(static_cast(1.0f)); case BF16: - return std::move( - *LiteralUtil::CreateR0(static_cast(1.0f))); + return LiteralUtil::CreateR0(static_cast(1.0f)); case F32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case F64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case C64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case PRED: - return std::move(*LiteralUtil::CreateR0(true)); + return LiteralUtil::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -184,42 +182,36 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case U32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case U64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case F32: - return std::move(*LiteralUtil::CreateR0( - -std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + -std::numeric_limits::infinity()); case F64: - return std::move(*LiteralUtil::CreateR0( - -std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + -std::numeric_limits::infinity()); case C64: LOG(FATAL) << "C64 element type has no minimum value"; case PRED: - return std::move(*LiteralUtil::CreateR0(false)); + return LiteralUtil::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*LiteralUtil::CreateR0( - static_cast(-std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case BF16: - return std::move(*LiteralUtil::CreateR0( - static_cast(-std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -232,40 +224,34 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case U32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case U64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case F32: - return std::move(*LiteralUtil::CreateR0( - std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + std::numeric_limits::infinity()); case F64: - return std::move(*LiteralUtil::CreateR0( - std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + std::numeric_limits::infinity()); case PRED: - return std::move(*LiteralUtil::CreateR0(true)); + return LiteralUtil::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*LiteralUtil::CreateR0( - static_cast(std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case BF16: - return std::move(*LiteralUtil::CreateR0( - static_cast(std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -275,31 +261,29 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } } -/* static */ std::unique_ptr LiteralUtil::CreateR1( +/* static */ Literal LiteralUtil::CreateR1( const tensorflow::core::Bitmap& values) { - auto literal = absl::make_unique( + Literal literal( ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); - literal->PopulateR1(values); + literal.PopulateR1(values); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR1U8( - absl::string_view value) { - auto literal = absl::make_unique( - ShapeUtil::MakeShape(U8, {static_cast(value.size())})); +/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) { + Literal literal(ShapeUtil::MakeShape(U8, {static_cast(value.size())})); for (int i = 0; i < value.size(); ++i) { - literal->Set({i}, value[i]); + literal.Set({i}, value[i]); } return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( - float from, float to, int64 rows, int64 cols) { +/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to, + int64 rows, int64 cols) { auto value = MakeLinspaceArray2D(from, to, rows, cols); return CreateR2FromArray2D(*value); } -/* static */ std::unique_ptr LiteralUtil::ReshapeSlice( +/* static */ Literal LiteralUtil::ReshapeSlice( absl::Span new_dimensions, absl::Span minor_to_major, const LiteralSlice& literal) { int64 new_num_elements = 1; @@ -309,13 +293,13 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - auto new_literal = absl::make_unique( + Literal new_literal( ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); // Create a new shape with the given minor-to-major layout. This shape is used // solely for converting linear address to multi-dimensional addresses when // writing elements to the new literal. - Shape shape_with_layout = new_literal->shape(); + Shape shape_with_layout = new_literal.shape(); *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); // Copy data into new literal, element-by-element. @@ -326,40 +310,40 @@ std::unique_ptr ConvertType(LiteralSlice literal) { IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); switch (literal.shape().element_type()) { case PRED: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U8: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case S32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case S64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case F32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case F64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case C64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; default: LOG(FATAL) << "Unhandled primitive element type: " @@ -376,97 +360,82 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0); switch (literal.shape().element_type()) { case PRED: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 8 bit types. case S8: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U8: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 16 bit types. case BF16: - return std::move(*LiteralUtil::CreateR0( - literal.GetFirstElement())); + return LiteralUtil::CreateR0( + literal.GetFirstElement()); case F16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 32 bit types. case F32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 64 bit types. case C64: - return std::move(*LiteralUtil::CreateR0( - literal.GetFirstElement())); + return LiteralUtil::CreateR0( + literal.GetFirstElement()); case F64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); default: LOG(FATAL) << "Unhandled primitive type " << literal.shape().element_type(); } } -/* static */ std::unique_ptr LiteralUtil::MakeTuple( +/* static */ Literal LiteralUtil::MakeTuple( absl::Span elements) { std::vector element_shapes; for (const auto* element : elements) { element_shapes.push_back(element->shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); + TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i})); } return literal; } -/* static */ std::unique_ptr LiteralUtil::MakeTupleFromSlices( +/* static */ Literal LiteralUtil::MakeTupleFromSlices( absl::Span elements) { std::vector element_shapes; for (const auto& element : elements) { element_shapes.push_back(element.shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i})); } return literal; } -/* static */ std::unique_ptr LiteralUtil::MakeTupleOwned( - std::vector> elements) { +/* static */ Literal LiteralUtil::MakeTupleOwned( + std::vector elements) { std::vector element_shapes; element_shapes.reserve(elements.size()); for (const auto& element : elements) { - element_shapes.push_back(element->shape()); + element_shapes.push_back(element.shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int64 i = 0; i < elements.size(); ++i) { TF_CHECK_OK( - literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i})); + literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); } return literal; } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 2d6084a67a3b966d054103df0f06ddb82d0d6525..2b181621ed92be8952ccec19e0d4229c494b9f47 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -69,36 +69,34 @@ class LiteralUtil { // The variants not ending with WithLayout use the default XLA layout for the // literal's linear representation in memory. template - static std::unique_ptr CreateR0(NativeT value); + static Literal CreateR0(NativeT value); template - static std::unique_ptr CreateR1(absl::Span values); - static std::unique_ptr CreateR1( - const tensorflow::core::Bitmap& values); + static Literal CreateR1(absl::Span values); + static Literal CreateR1(const tensorflow::core::Bitmap& values); template - static std::unique_ptr CreateR2( + static Literal CreateR2( std::initializer_list> values); template - static std::unique_ptr CreateR2WithLayout( + static Literal CreateR2WithLayout( std::initializer_list> values, const Layout& layout); template - static std::unique_ptr CreateR3( - std::initializer_list< - std::initializer_list>> - values); + static Literal CreateR3(std::initializer_list< + std::initializer_list>> + values); template - static std::unique_ptr CreateR3WithLayout( + static Literal CreateR3WithLayout( std::initializer_list< std::initializer_list>> values, const Layout& layout); template - static std::unique_ptr CreateR4( + static Literal CreateR4( std::initializer_list>>> values); template - static std::unique_ptr CreateR4WithLayout( + static Literal CreateR4WithLayout( std::initializer_list>>> values, @@ -139,9 +137,10 @@ class LiteralUtil { // [9, 10, 11]: 4.0 // template - static std::unique_ptr CreateSparse( - absl::Span dimensions, SparseIndexArray indices, - absl::Span values, bool sort = true); + static Literal CreateSparse(absl::Span dimensions, + SparseIndexArray indices, + absl::Span values, + bool sort = true); // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -155,130 +154,120 @@ class LiteralUtil { static Literal MaxValue(PrimitiveType primitive_type); // Creates a literal of the given shape where each element is `value`. template - static std::unique_ptr CreateFullWithDescendingLayout( + static Literal CreateFullWithDescendingLayout( absl::Span dimensions, NativeT value); // Creates a new literal from an Array type. The variants not ending with // WithLayout use the default XLA layout for the literal's linear // representation in memory. template - static std::unique_ptr CreateFromArray(const Array& values); + static Literal CreateFromArray(const Array& values); template - static std::unique_ptr CreateFromArrayWithLayout( - const Array& values, const Layout& layout); + static Literal CreateFromArrayWithLayout(const Array& values, + const Layout& layout); template - static std::unique_ptr CreateR2FromArray2D( - const Array2D& values); + static Literal CreateR2FromArray2D(const Array2D& values); template - static std::unique_ptr CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout); + static Literal CreateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); template - static std::unique_ptr CreateR3FromArray3D( - const Array3D& values); + static Literal CreateR3FromArray3D(const Array3D& values); template - static std::unique_ptr CreateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout); + static Literal CreateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); template - static std::unique_ptr CreateR4FromArray4D( - const Array4D& values); + static Literal CreateR4FromArray4D(const Array4D& values); template - static std::unique_ptr CreateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout); + static Literal CreateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); // Creates a new vector of U8s literal value from a string. - static std::unique_ptr CreateR1U8(absl::string_view value); + static Literal CreateR1U8(absl::string_view value); // Creates a linspace-populated literal with the given number of rows and // columns. - static std::unique_ptr CreateR2F32Linspace(float from, float to, - int64 rows, int64 cols); + static Literal CreateR2F32Linspace(float from, float to, int64 rows, + int64 cols); // Creates a literal that projects the (x, y) dimensions given in values into // the z dimension given by "projection". template - static std::unique_ptr CreateR3Projected( + static Literal CreateR3Projected( std::initializer_list> values, int64 projection); // Creates a literal that projects the (x, y) dimensions given in values into // the z and p dimensions given. template - static std::unique_ptr CreateR4Projected( + static Literal CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z); // Returns an identity matrix (rank 2) with the given row and column count. template - static std::unique_ptr MakeIdentityR2(int64 size); + static Literal MakeIdentityR2(int64 size); // Returns a tuple literal composed of given literals. Data is copied from the // given elements into the returned literal. - static std::unique_ptr MakeTuple( - absl::Span elements); + static Literal MakeTuple(absl::Span elements); - static std::unique_ptr MakeTupleFromSlices( - absl::Span elements); + static Literal MakeTupleFromSlices(absl::Span elements); // As above, but intended to be invoked with move semantics; i.e. // - // std::vector> elements = ...; + // std::vector elements = ...; // auto result = LiteralUtil::MakeTupleOwned(std::move(elements)); // // This would have been declared as an overload, but there is ambiguity // in invocation between the above signature and this one. - static std::unique_ptr MakeTupleOwned( - std::vector> elements); + static Literal MakeTupleOwned(std::vector elements); - // This overload lets you pass a braced list of unique_ptrs to + // This overload lets you pass a braced list of Literals to // MakeTupleOwned: // // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...). // - // Simply relying on the MakeTupleOwned(std::vector>) + // Simply relying on the MakeTupleOwned(std::vector) // overload doesn't work because std::initializer_list's elements are always // const. // - // The arguments to this function must all be unique_ptr. + // The arguments to this function must all be Literal. template - static std::unique_ptr MakeTupleOwned( - std::unique_ptr... elements) { - std::array, sizeof...(Ts)> arr{ - std::move(elements)...}; - std::vector> v; + static Literal MakeTupleOwned(Ts... elements) { + std::array arr{std::move(elements)...}; + std::vector v; v.insert(v.begin(), std::make_move_iterator(arr.begin()), std::make_move_iterator(arr.end())); return MakeTupleOwned(std::move(v)); } // Create a constant token literal. Token types have no value. - static std::unique_ptr CreateToken(); + static Literal CreateToken(); // Creates a new Literal object with its values havings the primitive_type // type, and with dimensions defined by the dimensions parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, absl::Span dimensions); + static Literal CreateFromDimensions(PrimitiveType primitive_type, + absl::Span dimensions); // If the given literal's data type is bfloat16, converts it to a float // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr ConvertBF16ToF32( - const LiteralSlice& bf16_literal); + static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); // If the given literal's data type is float, converts it to a bfloat16 // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr ConvertF32ToBF16( - const LiteralSlice& f32_literal); + static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); // Creates a literal with a new shape with the given new dimensions using the // data in the given input literal. For reshaping purposes the (flat) data // buffer of the input literal is assumed to have the given minor_to_major // layout order. - static std::unique_ptr ReshapeSlice( - absl::Span new_dimensions, - absl::Span minor_to_major, const LiteralSlice& literal); + static Literal ReshapeSlice(absl::Span new_dimensions, + absl::Span minor_to_major, + const LiteralSlice& literal); // Creates a literal with the supplied shape, and uses the provided value // generator to populate the literal's values. @@ -286,7 +275,7 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( + static StatusOr CreateRandomLiteral( const Shape& shape, const std::function)>& generator); @@ -297,8 +286,8 @@ class LiteralUtil { template < PrimitiveType type, typename E, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev); + static StatusOr CreateRandomLiteral(const Shape& shape, E* engine, + T mean, T stddev); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard @@ -307,8 +296,8 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, T mean, T stddev); + static StatusOr CreateRandomLiteral(const Shape& shape, T mean, + T stddev); // // End of factory methods. @@ -322,44 +311,43 @@ class LiteralUtil { std::ostream& operator<<(std::ostream& out, const Literal& literal); template -/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { - auto literal = absl::make_unique(ShapeUtil::MakeShape( +/* static */ Literal LiteralUtil::CreateR0(NativeT value) { + Literal literal(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {})); - literal->Set({}, value); + literal.Set({}, value); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR1( - absl::Span values) { - auto literal = absl::make_unique( +/* static */ Literal LiteralUtil::CreateR1(absl::Span values) { + Literal literal( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())})); - literal->PopulateR1(values); + literal.PopulateR1(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( +/* static */ Literal LiteralUtil::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { - auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, AsInt64Slice(layout.minor_to_major()))); - literal->PopulateR2(values); + literal.PopulateR2(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2( +/* static */ Literal LiteralUtil::CreateR2( std::initializer_list> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( +/* static */ Literal LiteralUtil::CreateR3WithLayout( std::initializer_list>> values, const Layout& layout) { @@ -384,14 +372,14 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR3( +/* static */ Literal LiteralUtil::CreateR3( std::initializer_list>> values) { return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( +/* static */ Literal LiteralUtil::CreateR4WithLayout( std::initializer_list>>> values, @@ -422,23 +410,22 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateSparse( +/* static */ Literal LiteralUtil::CreateSparse( absl::Span dimensions, SparseIndexArray indices, absl::Span values, bool sort) { int64 num_elements = values.size(); int64 rank = dimensions.size(); CHECK_EQ(num_elements, indices.index_count()); CHECK_EQ(rank, indices.rank()); - auto literal = - absl::make_unique(ShapeUtil::MakeShapeWithSparseLayout( - primitive_util::NativeToPrimitiveType(), dimensions, - indices.max_indices())); - literal->PopulateSparse(indices, values, sort); + Literal literal(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType(), dimensions, + indices.max_indices())); + literal.PopulateSparse(indices, values, sort); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR4( +/* static */ Literal LiteralUtil::CreateR4( std::initializer_list>>> values) { @@ -446,50 +433,48 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateFromArrayWithLayout( +/* static */ Literal LiteralUtil::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { - auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), AsInt64Slice(layout.minor_to_major()))); - literal->PopulateFromArray(values); + literal.PopulateFromArray(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateFromArray( +/* static */ Literal LiteralUtil::CreateFromArray( const Array& values) { return CreateFromArrayWithLayout( values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( +/* static */ Literal LiteralUtil::CreateR2FromArray2D( const Array2D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( +/* static */ Literal LiteralUtil::CreateR3FromArray3D( const Array3D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( +/* static */ Literal LiteralUtil::CreateR3Projected( std::initializer_list> values, int64 projection) { int64 dim0_size = projection; @@ -514,7 +499,7 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( +/* static */ Literal LiteralUtil::CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -542,21 +527,20 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( +/* static */ Literal LiteralUtil::CreateR4FromArray4D( const Array4D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } // Returns an identity matrix (rank 2) with the given row and column count. template -/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { +/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) { Array2D array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -565,33 +549,29 @@ template } template -/* static */ std::unique_ptr -LiteralUtil::CreateFullWithDescendingLayout(absl::Span dimensions, - NativeT value) { - auto literal = - absl::make_unique(ShapeUtil::MakeShapeWithDescendingLayout( - primitive_util::NativeToPrimitiveType(), dimensions)); - literal->PopulateWithValue(value); +/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout( + absl::Span dimensions, NativeT value) { + Literal literal(ShapeUtil::MakeShapeWithDescendingLayout( + primitive_util::NativeToPrimitiveType(), dimensions)); + literal.PopulateWithValue(value); return literal; } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral( +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( const Shape& shape, const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - auto literal = absl::make_unique(shape); - TF_RETURN_IF_ERROR(literal.get()->Populate( + Literal literal(shape); + TF_RETURN_IF_ERROR(literal.Populate( [&](absl::Span indexes) { return generator(indexes); })); return std::move(literal); } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, - T stddev) { +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; std::normal_distribution generator(mean, stddev); return CreateRandomLiteral( @@ -600,8 +580,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( + const Shape& shape, T mean, T stddev) { std::minstd_rand0 engine; return CreateRandomLiteral(shape, &engine, mean, stddev); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index bddb6641495b1ce6df2483b3b759e207b2f0ceec..0f86f9f35e105713aa3072a9ebf572d33d35d66d 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -40,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file) PackedLiteralReader::~PackedLiteralReader() { delete file_; } -StatusOr> PackedLiteralReader::Read( - const Shape& shape, const Layout* layout) { +StatusOr PackedLiteralReader::Read(const Shape& shape, + const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) << " layout: " << (layout == nullptr ? "" : layout->ShortDebugString()); @@ -58,14 +57,14 @@ StatusOr> PackedLiteralReader::Read( PrimitiveType_Name(shape.element_type())); } - auto result = absl::make_unique(literal_shape); - result->PopulateWithValue(std::numeric_limits::quiet_NaN()); + Literal result(literal_shape); + result.PopulateWithValue(std::numeric_limits::quiet_NaN()); int64 elements = ShapeUtil::ElementsIn(shape); - absl::Span field = result->data(); + absl::Span field = result.data(); char* data = absl::bit_cast(field.data()); uint64 bytes = elements * sizeof(float); - tensorflow::StringPiece sp; + absl::string_view sp; auto s = file_->Read(offset_, bytes, &sp, data); offset_ += sp.size(); if (!s.ok()) { @@ -86,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const { // Try to read a single byte from offset_. If we can't, we've // exhausted the data. char single_byte[1]; - tensorflow::StringPiece sp; + absl::string_view sp; auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte); return !s.ok(); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h index 98dccaa9a246520bf60217b96d67a13a24c34b4a..d6d2ff1521bab341b166c4f5c1dc0917e28573d8 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.h +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -41,8 +41,7 @@ class PackedLiteralReader { // // Layout is optional. If it is not provided, no layout is set on the literal // that is produced. - StatusOr> Read(const Shape& shape, - const Layout* layout = nullptr); + StatusOr Read(const Shape& shape, const Layout* layout = nullptr); // Returns whether the input file has been fully exhausted; i.e. all available // packed literals have been read and we're at the end of the file. diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index 787725e884c810fd724ab88ad7d4beaf3e0a6cc7..b507a2ef79f1d7e9ae632744675dddf574490805 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { @@ -49,16 +50,40 @@ string SanitizeFilename(const string& file_name) { return safe_file_name; } +std::pair>*> +GetDirectoryExpanders() { + static auto* mutex = new tensorflow::mutex; + static auto* singleton = new std::vector>; + return {mutex, singleton}; +} + +// Runs all the directory expanders over x and returns the result. +string Expand(string x) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + for (const auto& f : *pair.second) { + x = f(x); + } + return x; +} + } // namespace Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name) { tensorflow::Env* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + string expanded_dir = Expand(directory); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir)); string safe_file_name = SanitizeFileName(file_name) + ".pb"; - const string path = tensorflow::io::JoinPath(directory, safe_file_name); + const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name); return tensorflow::WriteBinaryProto(env, path, message); } +void RegisterDirectoryExpander(const std::function& expander) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + pair.second->push_back(expander); +} + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 3667621367c7639c40ff17aee7b77305d4d34e33..f22fc8b8499dd4a5329276040331a2ed9e89bea9 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -39,6 +39,10 @@ extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name); +// Registers a function that may either expand a dirpath or forward the original +// dirpath along as-is. +void RegisterDirectoryExpander(const std::function& expander); + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index cd6e20b69366c064e20c6e0a7d1aebe6229690d8..9da5dc0d2d40cb10640fb0fd2c4c65b4f8e55346 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, return client->TransferToInfeedLocal(literal, device_ordinal); } -StatusOr> TransferFromOutfeedLocalReplica( - const Shape& shape, int replica_number) { +StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, + int replica_number) { VLOG(1) << "Outfeeding literal from replica number: " << replica_number << " shape: " << shape; LocalClient* client = GetOrCreateLocalClient(); @@ -141,9 +141,8 @@ StatusOr LocalShapedBuffer::FromLiteral( LocalClient* client = GetOrCreateLocalClient(); StatusOr buf = [&] { if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, *relaid); + Literal relaid = argument.Relayout(shape_with_layout.value()); + return ToBuffer(client, /*device_ordinal=*/0, relaid); } return ToBuffer(client, /*device_ordinal=*/0, argument); }(); @@ -151,7 +150,7 @@ StatusOr LocalShapedBuffer::FromLiteral( return new LocalShapedBuffer(std::move(buf).ValueOrDie()); } -StatusOr> LocalShapedBuffer::ToLiteral() const { +StatusOr LocalShapedBuffer::ToLiteral() const { LocalClient* client = GetOrCreateLocalClient(); return client->ShapedBufferToLiteral(*shaped_buffer()); } @@ -160,7 +159,7 @@ CompiledLocalComputation::CompiledLocalComputation( std::unique_ptr executable) : executable_(std::move(executable)) {} -StatusOr> CompiledLocalComputation::Execute( +StatusOr CompiledLocalComputation::Execute( const std::vector& arguments, const std::vector>& shapes_with_layout) { LocalClient* client = GetOrCreateLocalClient(); @@ -169,7 +168,7 @@ StatusOr> CompiledLocalComputation::Execute( // Each replica populates a StatusOr result, but only replica zero actually // retrieves its literal value. - std::vector>> results(GetReplicaCount()); + std::vector> results(GetReplicaCount()); { tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", GetReplicaCount()); @@ -198,9 +197,8 @@ StatusOr> CompiledLocalComputation::Execute( StatusOr pushed; if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - pushed = ToBuffer(client, device_ordinal, *relaid); + Literal relaid = argument.Relayout(shape_with_layout.value()); + pushed = ToBuffer(client, device_ordinal, relaid); } else { pushed = ToBuffer(client, device_ordinal, argument); } diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 78b3c598b97294d2ba4deb72ec9c1251ef68b7cf..1d5dfe591175735d58a5fe555fffc8043fa4de7e 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); // Transfers a literal of the given shape from the outfeed of the given replica. // // The replica number is resolved to an appropriate device ordinal. -StatusOr > TransferFromOutfeedLocalReplica( - const Shape& shape, int replica_number); +StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, + int replica_number); // Wraps a ScopedShapedBuffer produced by copying a literal "to // device," i.e. copying a literal to a scoped buffer via the local @@ -65,7 +65,7 @@ class LocalShapedBuffer { LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); const ScopedShapedBuffer* shaped_buffer() const; - StatusOr > ToLiteral() const; + StatusOr ToLiteral() const; // Transfers ownership of the encapsulated ShapedBuffer to the caller, // analogous to std::unique_ptr::release(). @@ -117,7 +117,7 @@ class CompiledLocalComputation { // with optionally-specified argument layouts. The literals will be // re-laid out according to the corresponding elements of // shapes_with_layout. - StatusOr > Execute( + StatusOr Execute( const std::vector& arguments, const std::vector >& shapes_with_layout); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 450d3fe5af374f5b536183f963dcc5b19648a6c4..521490e76c138553c5cc6895412eadb35a939881 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -216,9 +216,9 @@ tensorflow::ImportNumpy(); } -%typemap(out) StatusOr< std::unique_ptr > { +%typemap(out) StatusOr { if ($1.ok()) { - std::unique_ptr value = $1.ConsumeValueOrDie(); + Literal value = $1.ConsumeValueOrDie(); $result = numpy::PyObjectFromXlaLiteral(*value); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -346,25 +346,25 @@ tensorflow::ImportNumpy(); // Literal -%typemap(in) const Literal& (StatusOr< std::unique_ptr > literal_status) { +%typemap(in) const Literal& (StatusOr literal_status) { literal_status = numpy::XlaLiteralFromPyObject($input); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); SWIG_fail; } - $1 = literal_status.ValueOrDie().get(); + $1 = &literal_status.ValueOrDie(); } -%typemap(out) std::unique_ptr { +%typemap(out) Literal { $result = numpy::PyObjectFromXlaLiteral(*$1); } -%typemap(out) StatusOr< std::unique_ptr > { +%typemap(out) StatusOr { if (!$1.ok()) { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); SWIG_fail; } - $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie()); + $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); } %typemap(in) const std::vector& (std::vector temps) { @@ -375,13 +375,13 @@ tensorflow::ImportNumpy(); const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - StatusOr< std::unique_ptr > literal_status = numpy::XlaLiteralFromPyObject(o); + StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); Py_DECREF(o); SWIG_fail; } - temps.push_back(std::move(*literal_status.ConsumeValueOrDie())); + temps.push_back(literal_status.ConsumeValueOrDie()); Py_DECREF(o); } $1 = &temps; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index fc6511bef566cb6f4e0d4e52972954de0792e959..b0aa024c7474cf8e6934432b2f364be464714999 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -368,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { } } -StatusOr> XlaLiteralFromPyObject(PyObject* o) { +StatusOr XlaLiteralFromPyObject(PyObject* o) { if (PyTuple_Check(o)) { int num_elements = PyTuple_Size(o); - std::vector> elements; + std::vector elements; elements.reserve(num_elements); for (int i = 0; i < num_elements; i++) { PyObject* element = PyTuple_GetItem(o, i); @@ -389,8 +389,7 @@ StatusOr> XlaLiteralFromPyObject(PyObject* o) { int np_type = PyArray_TYPE(py_array); auto literal = LiteralUtil::CreateFromDimensions( NumpyTypeToPrimitiveType(np_type), dimensions); - TF_RETURN_IF_ERROR( - CopyNumpyArrayToLiteral(np_type, py_array, literal.get())); + TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal)); return std::move(literal); } else { return InvalidArgument( diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 8cae1751853f3cd18033ecf6edca40bf99c6d917..40ff2d9ad214cc4dcad42234fa296834cbc92882 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); // To avoid transferring ownership of the data buffers that underlie // PyArrays and XLA literals, this function makes deep copies of all // array data. -StatusOr > XlaLiteralFromPyObject(PyObject* o); +StatusOr XlaLiteralFromPyObject(PyObject* o); // The following functions copy array data from the buffers underlying Numpy // ndarrays into those underlying XLA literals, and vice versa. diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 9f1afa2671b34e80c0d2ea3050c58c2e87844c63..ceb5e74db7c3b9305e9d77068df9ae0a3690af8a 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -186,11 +186,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow1DGeneric( - const absl::Span& operand, float init, + absl::Span operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{static_cast(operand.size())}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); @@ -218,10 +217,9 @@ ReferenceUtil::ReduceWindow1DGeneric( } /* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow1DAdd(const absl::Span& operand, - float init, - const absl::Span& window, - const absl::Span& stride, +ReferenceUtil::ReduceWindow1DAdd(absl::Span operand, float init, + absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{static_cast(operand.size())}; @@ -234,9 +232,8 @@ ReferenceUtil::ReduceWindow1DAdd(const absl::Span& operand, ReferenceUtil::ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{operand.height(), operand.width()}; std::vector window_counts(window.size(), 0); @@ -273,9 +270,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( - const Array2D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + const Array2D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{operand.height(), operand.width()}; return ReduceWindow2DGeneric( @@ -284,9 +280,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( - const Array3D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + const Array3D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -332,8 +327,8 @@ ReferenceUtil::ReduceWindow2DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + absl::Span window, absl::Span stride, + Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; return ReduceWindow4DGeneric( @@ -345,9 +340,8 @@ ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; @@ -399,9 +393,8 @@ ReferenceUtil::ReduceWindow4DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow4DAdd( - const Array4D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + const Array4D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, padding); @@ -425,8 +418,8 @@ ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::SelectAndScatter4DGePlus(const Array4D& operand, const Array4D& source, float init, - const absl::Span& window, - const absl::Span& stride, + absl::Span window, + absl::Span stride, bool same_padding) { Padding padding = same_padding ? Padding::kSame : Padding::kValid; auto result = absl::make_unique>(operand.n1(), operand.n2(), @@ -529,13 +522,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( } ordered_input_dimensions[0] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0)); ordered_input_dimensions[1] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1)); ordered_kernel_dimensions[0] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)); ordered_kernel_dimensions[1] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)); std::vector> paddings = MakePadding(ordered_input_dimensions, ordered_kernel_dimensions, @@ -546,7 +539,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim; dim.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0))); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0))); dim.set_stride(kernel_stride.first); dim.set_padding_low(paddings[0].first); dim.set_padding_high(paddings[0].second); @@ -556,7 +549,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim2; dim2.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1))); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1))); dim2.set_stride(kernel_stride.second); dim2.set_padding_low(paddings[1].first); dim2.set_padding_high(paddings[1].second); @@ -565,7 +558,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( *window.add_dimensions() = dim2; const Shape& shape = ShapeInference::InferConvolveShape( - lhs_literal->shape(), rhs_literal->shape(), + lhs_literal.shape(), rhs_literal.shape(), /*feature_group_count=*/1, window, dnums) .ConsumeValueOrDie(); @@ -585,18 +578,18 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( auto computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; - std::unique_ptr result_literal = + Literal result_literal = evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); - CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); + CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4); auto result = - absl::make_unique>(result_literal->shape().dimensions(0), - result_literal->shape().dimensions(1), - result_literal->shape().dimensions(2), - result_literal->shape().dimensions(3)); + absl::make_unique>(result_literal.shape().dimensions(0), + result_literal.shape().dimensions(1), + result_literal.shape().dimensions(2), + result_literal.shape().dimensions(3)); result->Each([&](absl::Span indices, float* value) { - *value = result_literal->Get(indices); + *value = result_literal.Get(indices); }); return result; diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 9ce098029dbc35f6b4bab2efd77bee2b7e1a6255..8654fbb9b5e16c5ac13cb29aafeef8d142dbe39f 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -177,47 +177,41 @@ class ReferenceUtil { // Windowed reductions with Add as the function to apply. static std::unique_ptr> ReduceWindow1DAdd( - const absl::Span& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + absl::Span operand, float init, + absl::Span window, absl::Span stride, + Padding padding); static std::unique_ptr> ReduceWindow2DAdd( - const Array2D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + const Array2D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); static std::unique_ptr> ReduceWindow3DAdd( - const Array3D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + const Array3D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); static std::unique_ptr> ReduceWindow4DAdd( - const Array4D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + const Array4D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); // Windowed reductions with a generic reduce function. static std::unique_ptr> ReduceWindow1DGeneric( - const absl::Span& operand, float init, + absl::Span operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); static std::unique_ptr> ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, Padding padding); + absl::Span window, absl::Span stride, + Padding padding); // With arbitrary padding. static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); // Batch normalize data. static std::unique_ptr> BatchNorm4D( @@ -230,8 +224,8 @@ class ReferenceUtil { // TODO(b/74533103) Switch tests to evaluator and remove this implementation. static std::unique_ptr> SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, - const absl::Span& window, - const absl::Span& stride, bool same_padding); + absl::Span window, absl::Span stride, + bool same_padding); // Concatenates the lhs and rhs arrays along the concatenate_dimension. // E.g. if concatenate_dimension is 0, the "n1"/height dimension is @@ -332,8 +326,8 @@ class ReferenceUtil { // Slices with index clamping template - static std::vector ClampSlice1D(const absl::Span& input, - int64 start, int64 size) { + static std::vector ClampSlice1D(absl::Span input, int64 start, + int64 size) { start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 3ec0192148492c2516bf1c14fd4b960b08014388..a1b0f4045ff071454451f9fe3942ac974f4f47ac 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MatmulArray2D) { @@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); auto actual_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, + LiteralTestUtil::ExpectR1Near({6.f, 15.f}, actual_literal, ErrorSpec(0.0001)); } @@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); auto actual_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, + LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, actual_literal, ErrorSpec(0.0001)); } @@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { auto result = LiteralUtil::CreateR1(ReferenceUtil::Reduce4DTo1D( Array4D(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2}, [](float a, float b) { return a + b; })); - LiteralTestUtil::ExpectR1Equal({0}, *result); + LiteralTestUtil::ExpectR1Equal({0}, result); } TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, + LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal, ErrorSpec(0.0001)); } @@ -108,7 +108,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray4D) { @@ -121,7 +121,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) { Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -146,16 +146,16 @@ TEST_F(ReferenceUtilTest, SliceArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}}); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, - *actual_literal, ErrorSpec(0.0001)); + LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceStridedArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}}); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, - *actual_literal, ErrorSpec(0.0001)); + LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceArray3D) { @@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( - {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal, + {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal, ErrorSpec(0.0001)); } @@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( - {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, - *actual_literal, ErrorSpec(0.0001)); + {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceArray4D) { @@ -194,7 +194,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) { LiteralTestUtil::ExpectR4Near( {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceStridedArray4D) { @@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) { LiteralTestUtil::ExpectR4Near( {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}}, {{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { @@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); - LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + LiteralTestUtil::ExpectR3NearArray3D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); - LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + LiteralTestUtil::ExpectR3NearArray3D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) { [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } } // namespace diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 43fd8fe1bd0f41eb2ac5c42021a8ca4f63282646..84fe5b17d10fba8c9f44314bec2b827e98ff6b33 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -95,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { std::vector expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); + Literal expected_literal = LiteralUtil::CreateR1(expected); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); - EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal, + EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal, ErrorSpec(0.0001))); } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index e784663ff65eedbea965e9fd1a561da28e10a303..fb80c78f6852db7d69aeef752b5f692d47d58bed 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -87,6 +87,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -123,6 +124,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -352,6 +354,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -402,6 +405,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -498,6 +502,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -546,6 +551,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -568,6 +574,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1012,8 +1019,8 @@ cc_library( ":buffer_value_containers", ":heap_simulator", ":hlo", + ":hlo_memory_scheduler", ":hlo_proto", - ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1041,8 +1048,8 @@ tf_cc_test( ":cpu_plugin", ":flatten_call_graph", ":hlo", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1088,8 +1095,8 @@ tf_cc_test( deps = [ ":hlo", ":hlo_dataflow_analysis", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -1131,6 +1138,7 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1138,6 +1146,37 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_module_group", + srcs = ["hlo_module_group.cc"], + hdrs = ["hlo_module_group.h"], + deps = [ + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_module_group_test", + srcs = ["hlo_module_group_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":hlo_module_group", + ":hlo_parser", + ":hlo_proto", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_module_group_metadata", srcs = ["hlo_module_group_metadata.cc"], @@ -1185,9 +1224,9 @@ tf_cc_test( ":heap_simulator", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", ":hlo_parser", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -1199,13 +1238,14 @@ tf_cc_test( ) cc_library( - name = "hlo_scheduling", - srcs = ["hlo_scheduling.cc"], - hdrs = ["hlo_scheduling.h"], + name = "hlo_memory_scheduler", + srcs = ["hlo_memory_scheduler.cc"], + hdrs = ["hlo_memory_scheduler.h"], deps = [ ":heap_simulator", ":hlo", ":hlo_ordering", + ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1219,15 +1259,15 @@ cc_library( ) tf_cc_test( - name = "hlo_scheduling_test", - srcs = ["hlo_scheduling_test.cc"], + name = "hlo_memory_scheduler_test", + srcs = ["hlo_memory_scheduler_test.cc"], deps = [ ":heap_simulator", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", ":hlo_parser", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -1259,6 +1299,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1392,6 +1433,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -1708,6 +1750,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -1777,6 +1820,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", @@ -1953,6 +1997,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_matchers", + ":hlo_memory_scheduler", ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -2236,6 +2281,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2314,6 +2360,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2394,12 +2441,11 @@ cc_library( ":buffer_liveness", ":buffer_value", ":call_graph", - ":copy_insertion", ":flatten_call_graph", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -2428,6 +2474,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2494,6 +2541,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2611,6 +2659,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2888,6 +2937,7 @@ tf_cc_test( deps = [ ":hlo_tfgraph_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 3d18fe3be24484ff68bcb3c4b41be57f920665eb..5458159d149c627b1121fd8a30e073b712542390 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -205,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -296,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return scalar_add_computation_; } + // Tries to fold a kPad in the input or filter into the convolution + // instruction's window. + StatusOr FoldConvInputPad(HloInstruction* convolution); + StatusOr FoldConvFilterPad(HloInstruction* convolution); + + // Tries to use a kDot in place of the given convolution. + StatusOr SimplifyConvToDot(HloInstruction* convolution); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -312,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. bool enable_dot_strength_reduction_; - // Disable convolution simplification on platforms where it causes a slowdown. + // Disable convolution -> dot simplification on platforms where it causes a + // slowdown. bool enable_conv_simplification_; // Cached computation for adding two scalar F32. @@ -527,7 +536,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { return computation->AddInstruction( - HloInstruction::CreateConstant(literal.CloneToUnique())); + HloInstruction::CreateConstant(literal.Clone())); } } @@ -546,7 +555,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr unique_scalar = absl::make_unique( + Literal unique_scalar( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); @@ -676,7 +685,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } auto inverse = computation_->AddInstruction( - HloInstruction::CreateConstant((new_literal.CloneToUnique()))); + HloInstruction::CreateConstant((new_literal.Clone()))); TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); return ReplaceInstruction(divide, new_divide); @@ -1469,7 +1478,7 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(iota->shape().element_type()).Clone())); return ReplaceWithNewInstruction( iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); } @@ -1572,7 +1581,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( - LiteralUtil::One(power->shape().element_type()).CloneToUnique()); + LiteralUtil::One(power->shape().element_type()).Clone()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -1607,7 +1616,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::One(rhs->shape().element_type()).CloneToUnique())); + LiteralUtil::One(rhs->shape().element_type()).Clone())); // Explicitly broadcast scalar 1 to the output shape, to avoid implicit // broadcast in divide HLO as we are trying to eliminate implicit @@ -2057,12 +2066,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( if (pad_literal == reduce_init_literal) { return true; } - auto converted_pad_literal = pad_literal.ConvertToShape( - reduce_init_value->shape(), /*round_f32_to_bf16=*/true); + auto converted_pad_literal = + pad_literal.ConvertToShape(reduce_init_value->shape()); if (!converted_pad_literal.ok()) { return false; } - return *converted_pad_literal.ValueOrDie() == reduce_init_literal; + return converted_pad_literal.ValueOrDie() == reduce_init_literal; }; // The pad value is usually a constant, so we handle that case and do not // try to get more fancy about proving equivalence in cases beyond that. @@ -2212,170 +2221,155 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleConvolution( +StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( HloInstruction* convolution) { - auto lhs = convolution->mutable_operand(0); - auto rhs = convolution->mutable_operand(1); - if (ShapeUtil::IsZeroElementArray(lhs->shape()) || - ShapeUtil::IsZeroElementArray(rhs->shape())) { - return ReplaceWithNewInstruction( - convolution, - HloInstruction::CreateBroadcast( - convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(convolution->shape().element_type()) - .CloneToUnique())), - {})); - } - + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); const auto& window = convolution->window(); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - // Try to merge padding/dilation of the input with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr { - if (lhs->opcode() != HloOpcode::kPad) { + if (lhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(lhs->operand(1), 0)) { + return false; + } + + const auto& padding = lhs->padding_config(); + + // Can't pad batch or feature dims. + for (int64 dim : + {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { return false; } + } - // Convolution's padding is always zero, so bail if the kPad is adding - // something other than zero. - if (!IsAll(lhs->operand(1), 0)) { + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = window; + for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); + // Edge padding composes with itself in the straightforward way, but + // composing interior padding is nontrivial, and we cowardly refuse to + // think about it. If we see interior padding in either the kPad or conv, + // bail if there's any sort of padding in the other. + if (p.interior_padding() != 0 && + (w.padding_low() != 0 || w.padding_high() != 0 || + w.base_dilation() != 1)) { + return false; + } + if (w.base_dilation() != 1 && + (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0)) { return false; } - const auto& padding = lhs->padding_config(); - - // Can't pad batch or feature dims. - for (int64 dim : - {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { - const auto& p = padding.dimensions(dim); - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0) { - return false; - } + w.set_padding_low(w.padding_low() + p.edge_padding_low()); + w.set_padding_high(w.padding_high() + p.edge_padding_high()); + if (p.interior_padding() != 0) { + CHECK_EQ(w.base_dilation(), 1); + w.set_base_dilation(1 + p.interior_padding()); } + } - // Compute the window which is the result of merging the kPad and the - // convolution's existing window. - Window new_window = window; - for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { - auto& w = *new_window.mutable_dimensions(dim); - const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); - // Edge padding composes with itself in the straightforward way, but - // composing interior padding is nontrivial, and we cowardly refuse to - // think about it. If we see interior padding in either the kPad or conv, - // bail if there's any sort of padding in the other. - if (p.interior_padding() != 0 && - (w.padding_low() != 0 || w.padding_high() != 0 || - w.base_dilation() != 1)) { - return false; - } - if (w.base_dilation() != 1 && - (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0)) { - return false; - } + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs->mutable_operand(0), rhs}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} - w.set_padding_low(w.padding_low() + p.edge_padding_low()); - w.set_padding_high(w.padding_high() + p.edge_padding_high()); - if (p.interior_padding() != 0) { - CHECK_EQ(w.base_dilation(), 1); - w.set_base_dilation(1 + p.interior_padding()); - } - } +StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); - auto new_conv = convolution->CloneWithNewOperands( - convolution->shape(), {lhs->mutable_operand(0), rhs}); - new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(convolution, std::move(new_conv))); - return true; - }()); + if (rhs->opcode() != HloOpcode::kPad) { + return false; + } - if (folded_input_pad) { - return Status::OK(); + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(rhs->operand(1), 0)) { + return false; } - // Try to merge dilation of the filter with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr { - if (rhs->opcode() != HloOpcode::kPad) { - return false; - } + const auto& padding = rhs->padding_config(); - // Convolution's padding is always zero, so bail if the kPad is adding - // something other than zero. - if (!IsAll(rhs->operand(1), 0)) { + // Can't pad or dilate feature dims. + for (int64 dim : {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { return false; } + } - const auto& padding = rhs->padding_config(); + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = convolution->window(); + for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); - // Can't pad or dilate feature dims. - for (int64 dim : {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}) { - const auto& p = padding.dimensions(dim); - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0) { - return false; - } + // We can only do this transformation if p adds dilation to the filter -- + // edge padding on the filter is not supported in conv. + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { + return false; } - // Compute the window which is the result of merging the kPad and the - // convolution's existing window. - Window new_window = convolution->window(); - for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { - auto& w = *new_window.mutable_dimensions(dim); - const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); - - // We can only do this transformation if p adds dilation to the filter -- - // edge padding on the filter is not supported in conv. - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { - return false; - } - - // Nothing to do if the kPad for this dim is entirely a nop. - if (p.interior_padding() == 0) { - continue; - } + // Nothing to do if the kPad for this dim is entirely a nop. + if (p.interior_padding() == 0) { + continue; + } - // We cowardly refuse to think about how dilation composes with itself; - // bail if both the kPad and conv have dilation on this dimension. - if (w.window_dilation() > 1) { - return false; - } - CHECK_EQ(w.window_dilation(), 1); - w.set_window_dilation(1 + p.interior_padding()); - w.set_size(rhs->operand(0)->shape().dimensions( - dnums.kernel_spatial_dimensions(dim))); + // We cowardly refuse to think about how dilation composes with itself; + // bail if both the kPad and conv have dilation on this dimension. + if (w.window_dilation() > 1) { + return false; } + CHECK_EQ(w.window_dilation(), 1); + w.set_window_dilation(1 + p.interior_padding()); + w.set_size(rhs->operand(0)->shape().dimensions( + dnums.kernel_spatial_dimensions(dim))); + } - auto new_conv = convolution->CloneWithNewOperands( - convolution->shape(), {lhs, rhs->mutable_operand(0)}); - new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(convolution, std::move(new_conv))); - return true; - }()); + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs, rhs->mutable_operand(0)}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} - if (folded_filter_pad) { - return Status::OK(); - } +StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const auto& window = convolution->window(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); if (!enable_conv_simplification_) { - return Status::OK(); + return false; } - // HandleConvolution tries to replace a convolution with a DOT instruction. - // - // Only add when bitcasts can be used: - // - if bitcasts are not supported, then reshapes could be used but will - // end up with another copy. - // - if bitcasts are supported, the simplifier will be called again with - // bitcasts_ == true. - // TODO(cwhipkey): b/31337498, make this layout insensitive. + // TODO(b/31337498): For now, we cowardly refuse to do this optimization in + // layout-insensitive mode, for fear of adding nontrivial reshapes. if (!is_layout_sensitive_) { - return Status::OK(); + return false; } const Shape& input_shape = lhs->shape(); @@ -2388,7 +2382,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // Require the spatial dimensions in the kernel to have a bound of one. for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) { - return Status::OK(); + return false; } } @@ -2399,7 +2393,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // for a 1x1 window, so window dilation is no problem. if (window_util::HasStride(window) || window_util::HasPadding(window) || window_util::HasBaseDilation(window)) { - return Status::OK(); + return false; } // Also, the shapes must align for a rowmajor matmul: @@ -2425,7 +2419,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dnums.kernel_input_feature_dimension()) < PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_output_feature_dimension()))) { - return Status::OK(); + return false; } auto add_bitcast = [&](Shape shape, HloInstruction* operand) { @@ -2467,7 +2461,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( if (!valid_bitcast_callback_(input_shape, new_input_shape) || !valid_bitcast_callback_(filter_shape, new_filter_shape) || !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { - return Status::OK(); + return false; } auto new_lhs = add_bitcast(new_input_shape, lhs); @@ -2479,7 +2473,44 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, convolution->precision_config())); - return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); + TF_RETURN_IF_ERROR( + ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot))); + return true; +} + +Status AlgebraicSimplifierVisitor::HandleConvolution( + HloInstruction* convolution) { + // Zero-sized input or filter. + if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) { + return ReplaceWithNewInstruction( + convolution, + HloInstruction::CreateBroadcast( + convolution->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(convolution->shape().element_type()))), + {})); + } + + // Try to merge padding/dilation of the input with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution)); + if (folded_input_pad) { + return Status::OK(); + } + + // Try to merge dilation of the filter with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution)); + if (folded_filter_pad) { + return Status::OK(); + } + + // Try to replace the convolution with a kDot instruction. + TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); + if (replaced_with_dot) { + return Status::OK(); + } + + return Status::OK(); } bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index a0db4563fb1324f48eb6bf0577d91b81bb5a3e24..3fc1ba24271b40de0a24ed4c957cd83aca736f55 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2932,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; - std::unique_ptr value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get()}); + Literal elements[] = {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector)}; + Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto computation = module().AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index ec281ae68fe76bac4029058997c44b1f7e71aeae..30d33e0d3531bb5e931ebfa0b60c91523dd0cb44 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape feature_shape = scale->shape(); auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = add(HloInstruction::CreateBroadcast( operand_shape, add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); @@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( const Shape feature_shape = scale->shape(); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( operand_shape, computation_->AddInstruction( @@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); auto epsilon_activation = add( @@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( auto elements_per_feature_literal = LiteralUtil::CreateR0(elements_per_feature_int64); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); + elements_per_feature_literal.Convert(ptype)); auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index aba0d9bb5b977d89656580df46838eefb8cd6662..f7ac8f5482908af104554a1cf812370b9098cda7 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -29,14 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { -using BatchNormExpanderTest = HloTestBase; +using BatchNormExpanderTest = HloVerifiedTestBase; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -66,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -108,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -126,13 +126,13 @@ ENTRY entry { epsilon=0.001, feature_index=1, sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str)); + ParseAndVerifyModule(module_str); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie()); - for (auto* instruction : module->entry_computation()->instructions()) { + for (auto* instruction : module().entry_computation()->instructions()) { if (instruction->opcode() == HloOpcode::kParameter) { continue; } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 6363a21c3bafe8353a6ebfde405bb7a3736c2074..5f93740887aa7e61458990992fe0573883ff056d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16ConversionFoldingTest : public HloTestBase { +class BFloat16ConversionFoldingTest : public HloVerifiedTestBase { protected: + BFloat16ConversionFoldingTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16ConversionFolding fold(&bfloat16_support_); @@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(mul0->shape().element_type(), F32); @@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(sub0->shape().element_type(), F32); @@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert1); EXPECT_EQ(gte->shape().element_type(), F32); @@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_EQ(tuple->operand(0), gte_a); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 933cf873e05fe11b63846f85b97eb49bd21e5a6c..cef0eba14e9dd463d6c32b047211bf25a84478f6 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16NormalizationTest : public HloTestBase { +class BFloat16NormalizationTest : public HloVerifiedTestBase { protected: + BFloat16NormalizationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool Normalize(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16Normalization normalization(&bfloat16_support_); @@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(Normalize(module.get())); + EXPECT_FALSE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1); @@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), sub1); @@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), reduce); EXPECT_EQ(reduce->called_computations().size(), 1); @@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -286,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -317,7 +321,7 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(dot->shape().element_type(), F32); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 545a6ecfb1fca88c2c759e820f9d87a38b1941ca..58f78f8e24d0bc00a63e3583828cf8e01ae4531a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -675,10 +675,8 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { continue; } if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { - TF_ASSIGN_OR_RETURN( - auto converted_literal, - hlo->literal().ConvertToShape(hlo->shape(), - /*round_f32_to_bf16=*/true)); + TF_ASSIGN_OR_RETURN(auto converted_literal, + hlo->literal().ConvertToShape(hlo->shape())); auto new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 388fd5df99a6d6649d45e794032196f2f2b20bc6..e032b5c624c0151fd63c870e0f21ec97656d625f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -163,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)), dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)), dot->operand(1)->literal())); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 0f0af57626599adef0b477d506e5d3afd9c1c315..65fa951afe3e60652413206913640af38f5bb824 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 5a231c173dd7269a724a5d8ec5ce819d6d27942c..795beb9ff5ceb2998a85fbd03d8bb1d3b2febc12 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -30,11 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -1245,9 +1245,10 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output // is properly handled. auto builder = HloComputation::Builder(TestName()); + Literal elements[] = {LiteralUtil::CreateR0(0), + LiteralUtil::CreateR0(1)}; builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}))); + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 414bfe79990212109e92b92af3fd88c8729fb22a..17e50905059ad2c92784d14132c1cb1f46f35ade 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}); - auto inner_tuple1 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); + Literal elements0[] = {LiteralUtil::CreateR0(0), + LiteralUtil::CreateR0(1)}; + auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]}); + Literal element1 = LiteralUtil::CreateR0(3); + auto inner_tuple1 = LiteralUtil::MakeTuple({&element1}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( - inner_tuple0->shape(), tuple_constant, 0)); + inner_tuple0.shape(), tuple_constant, 0)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index cc80b7484313329104eec1ce71a150b47d8330c9..34f3f914d593bc603c4964663f9cafb70a136fd3 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,7 +31,7 @@ namespace { using ::testing::UnorderedElementsAre; -class CallGraphTest : public HloTestBase { +class CallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation( @@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(1, call_graph->nodes().size()); EXPECT_TRUE(call_graph->IsFlattened()); @@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) { HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) { HloComputation* entry_computation = module->AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) { HloComputation* entry_computation = module->AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); // The called computation is only called from one other computation, but there @@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(3, call_graph->nodes().size()); @@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); // Verify NearestAncestorsInSameComputation for various instructions in the @@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); std::vector visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { @@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); // Test visitation of only reachable nodes. { @@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. auto module = CreateNewModule(); module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 5d85a3f173d50a964420e720f5c9b416731d948c..e6b566543594a86eb5369ee9b7440f62618f6c5a 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -40,7 +40,7 @@ namespace { // Tests for call inlining that are most tractable at the HLO level (vs // ComputationBuilder API in call_test.cc). -using CallInlinerTest = HloTestBase; +using CallInlinerTest = HloVerifiedTestBase; TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // "inner" computation just has a control dependency from the "zero" value to @@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), @@ -91,6 +91,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { module->AddEmbeddedComputation(just_false.Build()); HloComputation::Builder call_false_builder(TestName() + ".call_false"); + call_false_builder.AddInstruction( + HloInstruction::CreateParameter(0, pred, "param")); call_false_builder.AddInstruction( HloInstruction::CreateCall(pred, {}, false_computation)); HloComputation* call_false = @@ -105,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); EXPECT_THAT( computation->root_instruction()->while_condition()->root_instruction(), @@ -161,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); } diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index e5a6c28478a7ebf87878c3937069f15cafe12615..96bd2616f5607de888a096f8392ceb68490276e3 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -97,7 +97,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, HloModule::CreateFromProto(instance.computation, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module)); hlo_modules.push_back(std::move(hlo_module)); } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 0826380f65460b851482fd61928eed84c3744aea..0ac4a65ec6ae55fabd2b48ea2982b94f9551c8d2 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { expanded_filter = add(HloInstruction::CreateConcatenate( expanded_filter_shape, concat_operands, input_feature_dim)); } - auto zero = add(HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(expanded_filter_shape.element_type())))); + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); auto zero_filter = add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); auto new_filter = add( diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2368ac8c6aa18e61b6e74c0f748424a7e27b784b..8cc522a59e9805ec86e9e69c8d6e5fa1a3ab682d 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -122,7 +122,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:hlo_proto_util", - "//tensorflow/compiler/xla/service:hlo_scheduling", + "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", @@ -801,6 +801,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -822,6 +823,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -946,6 +948,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -971,6 +974,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 05792795a10ad87a0261d41c05e5c2001a88bed1..2083f440fdd971db1b675d005664d25e6de53dbe 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,7 +32,7 @@ namespace cpu { using ::testing::ElementsAre; -class ConvCanonicalizationTest : public HloTestBase { +class ConvCanonicalizationTest : public HloVerifiedTestBase { public: ConvCanonicalizationTest() { for (int i = 0; i < 2; ++i) { @@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); @@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index e7b60759944cc9f06ba612482203a226c395b7d6..18fc144efe0023c0893adfcb16eda3341c0938d3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -77,12 +77,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index 4db7fa446ea9188940f930bcadf753bd3e6b79e3..c9fb34be1cd582c71618c770c892058c233c571a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) { return count; } -class CpuCopyInsertionTest : public HloTestBase { +class CpuCopyInsertionTest : public HloVerifiedTestBase { protected: void InsertCopies(HloModule* module) { CpuCopyInsertion copy_insertion; @@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*module), 3); @@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*subcomputation), 2); EXPECT_THAT(subcomputation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index 0f463e6de623fc6ab43d685ff2a5d6882ba7b8a2..be1208fb2df2a1a11a093810b5f6c2a83f468062 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class CpuHloSupportCheckerTest : public HloTestBase { +class CpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: CpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module.get()).status(); + Status status = checker().Run(module).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("CPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 942e2ddd3940fffd5d87518f059beaced3cdc925..55d5925642a97b1a0425c092c82070d4b8e59df3 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -37,21 +37,20 @@ int main(int argc, char** argv) { xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie()); // Transfer parameters. - std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR2( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR2( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr param1_data = - client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client->TransferToServer(param1_literal).ConsumeValueOrDie(); // Build computation. xla::XlaBuilder builder(""); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p1, p0, {0}); xla::StatusOr computation_status = builder.Build(); @@ -59,17 +58,16 @@ int main(int argc, char** argv) { // Execute and transfer result of computation. xla::ExecutionProfile profile; - xla::StatusOr> result = - client->ExecuteAndTransfer( - computation, - /*arguments=*/{param0_data.get(), param1_data.get()}, - /*execution_options=*/nullptr, - /*execution_profile=*/&profile); - std::unique_ptr actual = result.ConsumeValueOrDie(); + xla::StatusOr result = client->ExecuteAndTransfer( + computation, + /*arguments=*/{param0_data.get(), param1_data.get()}, + /*execution_options=*/nullptr, + /*execution_profile=*/&profile); + xla::Literal actual = result.ConsumeValueOrDie(); LOG(INFO) << absl::StrFormat("computation took %dns", profile.compute_time_ns()); - LOG(INFO) << actual->ToString(); + LOG(INFO) << actual.ToString(); return 0; } diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index 7d8e51f909e3db699b745f94a6c625407bc4a6e3..1a3d82de954318368d61e3feeb0345dc592dcd8b 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -19,14 +19,14 @@ limitations under the License. #include #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace cpu { namespace { -class ShapePartitionAssignerTest : public HloTestBase { +class ShapePartitionAssignerTest : public HloVerifiedTestBase { protected: typedef std::vector Vec; @@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { expected_partitions); } -class ShapePartitionIteratorTest : public HloTestBase { +class ShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector> Partition; }; @@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { } } -class RandomShapePartitionIteratorTest : public HloTestBase { +class RandomShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector> Partition; RandomShapePartitionIteratorTest() diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index f11aff0573d4245eb4f3cdec7fd650da505698f5..c55206eee7ae3c6e4410c59aebf529de98fd2de8 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -48,6 +48,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 22721051e54e2cf9590b60333c51d1d028bb28e9..1deb412064b02988a8d4a6d726969c948d354d47 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +34,7 @@ namespace xla { namespace cpu { namespace { -class CpuFusionTest : public HloTestBase { +class CpuFusionTest : public HloVerifiedTestBase { protected: CpuFusionTest() {} @@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { auto builder = HloComputation::Builder(TestName()); auto input_literal1 = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); auto input_literal2 = LiteralUtil::CreateR1({-2.0, -42.0, 2.0}); - Shape vshape = input_literal1->shape(); + Shape vshape = input_literal1.shape(); auto input1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal1))); @@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -75,16 +75,16 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { EXPECT_EQ(4, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. - LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, *result, error_spec_); + LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, result, error_spec_); } TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -122,11 +122,10 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { EXPECT_EQ(8, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. - LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, *result, - error_spec_); + LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, result, error_spec_); } TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { @@ -135,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -184,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -209,11 +208,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { << fusion_instruction2->fused_instructions_computation()->ToString(); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0, 14.0, 40.0, 40.0}, - *result, error_spec_); + result, error_spec_); } TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { @@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // each fusion instruction to ensure that negate is not duplicated. auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -256,7 +255,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // Run fusion. CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); auto fusion1 = result->operand(0); auto fusion2 = result->operand(1); @@ -315,7 +314,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The only fusion instruction should be operand 0 of the tuple (formerly // negate1). diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index c35569c6619ba5b534c5d8bb7ad683d84b6ecf4b..5cc6d01c0f15d4209cbc1fb259a0078fb9957f6e 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests Infeed operation used in a while loop, as in the code below. The @@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { // Send 5 Infeed data of shape F32[3]. ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({1, 2, 3}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({1, 2, 3}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({4, 5, 6}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({4, 5, 6}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({7, 8, 9}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({7, 8, 9}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({10, 11, 12}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({10, 11, 12}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({13, 14, 15}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({13, 14, 15}))); delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 3 infeed data should be added. - LiteralTestUtil::ExpectR0Near(45.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near(45.0f, result_literal, ErrorSpec{1e-7}); } // Tests two Infeed operations with a total order. The order is enforced by @@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({3, 4}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({3, 4}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({5, 6}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({5, 6}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8}).get(), - LiteralUtil::CreateR0(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({7, 8}), + LiteralUtil::CreateR0(false)}))); // Asynchronously launch the execution on the device. std::unique_ptr result; @@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). sleep(1); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8, 9}).get(), - LiteralUtil::CreateR0(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({7, 8, 9}), + LiteralUtil::CreateR0(false)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({4, 5, 6}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({4, 5, 6}), + LiteralUtil::CreateR0(true)}))); // Wait for the execution to be done, and transfer the result. delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 6 infeed data should be added. - LiteralTestUtil::ExpectR0Near(66.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near(66.0f, result_literal, ErrorSpec{1e-7}); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index bb105194f1c9001ca4d9fff9174e1ea7e5d8b72a..7af51db55af44ae1e437ea8e4de7427012cad82f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {}; TEST_F(CpuNoAliasTest, Concat) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* param_x = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "x")); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index 1b3be199f632a2aa6bd2c5a3820c7c5ce9b1382e..852f34e06df35242b13110ae4411b8c969c26019 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -56,9 +56,9 @@ ENTRY main { } )"; - std::unique_ptr lhs = LiteralUtil::CreateR3({{{1}, {2}}}); - std::unique_ptr rhs = LiteralUtil::CreateR3({{{3}, {4}}}); - RunTest(hlo_text, {lhs.get(), rhs.get()}); + Literal lhs = LiteralUtil::CreateR3({{{1}, {2}}}); + Literal rhs = LiteralUtil::CreateR3({{{3}, {4}}}); + RunTest(hlo_text, {&lhs, &rhs}); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 8f6608241ed02bbb7e9fde9b6d767c002435e777..5fbd73a5363b4cdbcaafedbe6f4e7bd6bb2a92d8 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -30,7 +30,7 @@ limitations under the License. namespace xla { namespace { -class FlattenCallGraphTest : public HloTestBase { +class FlattenCallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation() { @@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr flat_call_graph = CallGraph::Build(module.get()); + std::unique_ptr flat_call_graph = CallGraph::Build(module); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); EXPECT_EQ(1, c_node.caller_callsites().size()); } @@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(2, cond_node.caller_callsites().size()); } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(1, cond_node.caller_callsites().size()); } @@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(7, module->computation_count()); const CallGraphNode& c_node = call_graph->GetNode(c_computation); @@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, module->computation_count()); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); // The true and false computations must now be different. EXPECT_EQ(3, module->computation_count()); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 4ed91ef18768d09c252d1b73890637227f0ce717..bec02e14f951c6d905b7329be5c02896984279d0 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -125,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( device_memory.size()); // Element is array-shaped: transfer array data to device buffer. const auto subliteral = LiteralSlice(literal, index); - std::unique_ptr relayed_out_literal; + Literal relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { @@ -138,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( // Relayout data before transferring. relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); - source = relayed_out_literal->untyped_data(); + source = relayed_out_literal.untyped_data(); TF_RETURN_IF_ERROR(TransferBufferToDevice( stream, /*size=*/GetByteSizeRequirement(device_subshape), source, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 6791e15ee03bc9720a9a1391b1d1ed18f58c2941..64b96836280718f13ac5ee9f4a497ed54a273b19 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -108,6 +108,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -173,6 +174,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", @@ -370,6 +372,8 @@ cc_library( srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], deps = [ + ":backend_configs", + ":cudnn_convolution_runner", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -395,6 +399,7 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -813,9 +818,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], ) @@ -832,6 +837,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", @@ -901,6 +907,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 05448d863dd2cfe69ad70168be40cdea5bc7017f..3a23ac1d634161628b2bd2589d0260022868ba36 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" @@ -30,62 +31,32 @@ namespace gpu { using se::dnn::AlgorithmDesc; -ConvolutionThunk::ConvolutionThunk( - CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, - const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count, - int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo) - : Thunk(Kind::kConvolution, hlo), - convolution_kind_(convolution_kind), - input_buffer_(input_buffer), - filter_buffer_(filter_buffer), - output_buffer_(output_buffer), - tuple_result_buffer_(tuple_result_buffer), - scratch_buffer_(scratch_buffer), - input_shape_(input_shape), - filter_shape_(filter_shape), - output_shape_(output_shape), - window_(window), - dim_nums_(dim_nums), - feature_group_count_(feature_group_count), - algorithm_(algorithm), - tensor_ops_enabled_(tensor_ops_enabled) {} - Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - se::DeviceMemoryBase input_data = - buffer_allocations.GetDeviceAddress(input_buffer_); - se::DeviceMemoryBase filter_data = - buffer_allocations.GetDeviceAddress(filter_buffer_); - se::DeviceMemoryBase output_data = - buffer_allocations.GetDeviceAddress(output_buffer_); + CudnnConvParams params; + + params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_); + params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_); + params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_); se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); - se::dnn::AlgorithmConfig algorithm_config( - se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); + TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, ¶ms)); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - TF_RETURN_IF_ERROR(RunCudnnConvolution( - convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, - filter_data, output_data, scratch, window_, dim_nums_, - feature_group_count_, algorithm_config, stream)); + TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream)); // Figure out which of output/input/filter is the result produced by // this op, and write the result tuple. void* result_ptr = [&] { - switch (convolution_kind_) { + switch (params.kind) { case CudnnConvKind::kForward: - return output_data.opaque(); + return params.output_buf.opaque(); case CudnnConvKind::kBackwardInput: - return input_data.opaque(); + return params.input_buf.opaque(); case CudnnConvKind::kBackwardFilter: - return filter_data.opaque(); + return params.filter_buf.opaque(); } }(); void* ptrs[] = {result_ptr, scratch.opaque()}; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 68d67c40c56145a137398540e90b75b33642589f..d7d1f91fba7239ed1670119f5df623d025c1d368 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -32,7 +33,7 @@ limitations under the License. namespace xla { namespace gpu { -// This class stores everything that StreamExecutor needs to launch a BNN +// This class stores everything that StreamExecutor needs to launch a DNN // convolution. It is generated by IrEmitter. // // This is thread-compatible. @@ -41,27 +42,24 @@ class ConvolutionThunk : public Thunk { // Constructs a thunk for launching a DNN convolution. When run, it will // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // - // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that - // we should use the default (i.e. baseline) cudnn algorithm. - // // Note that "output" here doesn't refer to the output from running this // thunk, but rather to the "output" of a hypothetical forward convolution // that corresponds to this input+filter+output triple. That is, the result // generated by this thunk is "output" for forward convs, "input" for // backward-input convs, and "filter" for backward-filter convs. - // - // Semantics of null hlo_instruction argument are as in Thunk. - ConvolutionThunk(CudnnConvKind convolution_kind, - const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, - const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, - int64 feature_group_count, int64 algorithm, - bool tensor_ops_enabled, const HloInstruction* hlo); + ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, + BufferAllocation::Slice input_slice, + BufferAllocation::Slice filter_slice, + BufferAllocation::Slice output_slice, + BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice) + : Thunk(Kind::kConvolution, cudnn_call), + cudnn_call_(cudnn_call), + input_buffer_(std::move(input_slice)), + filter_buffer_(std::move(filter_slice)), + output_buffer_(std::move(output_slice)), + scratch_buffer_(std::move(scratch_slice)), + tuple_result_buffer_(std::move(tuple_result_slice)) {} ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; @@ -72,23 +70,12 @@ class ConvolutionThunk : public Thunk { HloExecutionProfiler* profiler) override; private: - const CudnnConvKind convolution_kind_; - - const BufferAllocation::Slice input_buffer_; - const BufferAllocation::Slice filter_buffer_; - const BufferAllocation::Slice output_buffer_; - const BufferAllocation::Slice tuple_result_buffer_; - const BufferAllocation::Slice scratch_buffer_; - - const Shape input_shape_; - const Shape filter_shape_; - const Shape output_shape_; - - const Window window_; - const ConvolutionDimensionNumbers dim_nums_; - int64 feature_group_count_; - int64 algorithm_; - bool tensor_ops_enabled_; + const HloCustomCallInstruction* cudnn_call_; + BufferAllocation::Slice input_buffer_; + BufferAllocation::Slice filter_buffer_; + BufferAllocation::Slice output_buffer_; + BufferAllocation::Slice scratch_buffer_; + BufferAllocation::Slice tuple_result_buffer_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 5c2555148ae5de4a15e5a5f003b4783c64a20e9c..f528e62b175758b4f4cf5ecff4dab4810ede2ed3 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/mutex.h" @@ -176,10 +177,14 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // caching would speed up compilation a lot. StatusOr> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - HloInstruction* instr) { + const HloCustomCallInstruction* instr) { + CudnnConvParams params; + TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, ¶ms)); + + const Shape& input_shape = *params.input_shape; + const Shape& filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + CHECK_EQ(input_shape.element_type(), filter_shape.element_type()); CHECK_EQ(input_shape.element_type(), output_shape.element_type()); // TODO(timshen): for now only check fp16. It can be expanded to other types, @@ -216,25 +221,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( allocator = &*se_allocator; } - // Allocate space for the input, filter, and output of the convolution. We - // use a ScratchAllocator for this instead of calling allocator_ directly so - // that our allocations don't leak. - ScratchAllocator input_output_allocator(device_ordinal, allocator); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(input_shape))); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(filter_shape))); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(output_shape))); - - if (cross_check_enabled) { - // Broadcast a constant to the buffer, instead of zeroing the buffer. A - // non-zero constant is useful for the cross checking, because zero-inputs - // may not always reveal the bugs. - const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) { + const auto initialize_buffer = [&stream, cross_check_enabled]( + DeviceMemoryBase buffer) { + if (cross_check_enabled) { + // Broadcast a constant to the buffer, instead of zeroing the buffer. A + // non-zero constant is useful for the cross checking, because zero-inputs + // may not always reveal the bugs. CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); size_t left_over_bytes = buffer.size() % 4; CHECK_EQ(0, left_over_bytes % 2); @@ -252,33 +244,46 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( DeviceMemoryBase left_over( static_cast(buffer.opaque()) + aligned_size, left_over_bytes); stream.ThenMemcpy(&left_over, halfs, left_over_bytes); - }; - initialize_f16(input_buf); - initialize_f16(filter_buf); - initialize_f16(output_buf); - } else { - // Although we don't have evidence this matters, zero out the buffers before - // autotuning. It's conceivable that using uninitialized memory as the - // inputs might affect performance if e.g. the inputs contain denormals, and - // this is easy enough. - stream.ThenMemZero(&input_buf, input_buf.size()) - .ThenMemZero(&filter_buf, filter_buf.size()) - .ThenMemZero(&output_buf, output_buf.size()); - } + } else { + // Although we don't have evidence this matters, zero out the buffers + // before autotuning. It's conceivable that using uninitialized memory as + // the inputs might affect performance if e.g. the inputs contain + // denormals, and this is easy enough. + stream.ThenMemZero(&buffer, buffer.size()); + } + }; + + // Allocate space for the input, filter, and output of the convolution. We + // use a ScratchAllocator for this instead of calling allocator_ directly so + // that our allocations don't leak. + ScratchAllocator input_output_allocator(device_ordinal, allocator); + TF_ASSIGN_OR_RETURN(params.input_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(input_shape))); + TF_ASSIGN_OR_RETURN(params.filter_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(filter_shape))); + TF_ASSIGN_OR_RETURN(params.output_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(output_shape))); + + initialize_buffer(params.input_buf); + initialize_buffer(params.filter_buf); + initialize_buffer(params.output_buf); DeviceMemoryBase* result_buf = [&] { - switch (kind) { + switch (params.kind) { case CudnnConvKind::kBackwardFilter: - return &filter_buf; + return ¶ms.filter_buf; case CudnnConvKind::kBackwardInput: - return &input_buf; + return ¶ms.input_buf; case CudnnConvKind::kForward: - return &output_buf; + return ¶ms.output_buf; } }(); const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( - input_shape, output_shape, dnums, stream_exec_); + input_shape, output_shape, *params.dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; @@ -288,18 +293,16 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // this algorithm considered correct, though. optional first_algorithm; for (const AlgorithmDesc& alg : - GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { + GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - bool launch_ok = - RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, input_buf, - filter_buf, output_buf, &scratch_allocator, window, dnums, - feature_group_count, AlgorithmConfig(alg), &stream, &profile_result) - .ok(); + params.algorithm = AlgorithmConfig(alg); + bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream, + &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { const bool crash_on_checking_failure = @@ -374,34 +377,8 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( HloInstruction* instr) { CHECK(IsCustomCallToDnnConvolution(*instr)); - const auto& call_target = instr->custom_call_target(); - const auto& lhs_shape = instr->operand(0)->shape(); - const auto& rhs_shape = instr->operand(1)->shape(); - const auto& conv_result_shape = instr->shape().tuple_shapes(0); - StatusOr> alg_scratch_and_tc; - if (call_target == kCudnnConvForwardCallTarget) { - alg_scratch_and_tc = - PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, instr->window(), - instr->convolution_dimension_numbers(), - instr->feature_group_count(), instr); - } else if (call_target == kCudnnConvBackwardInputCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), - instr->convolution_dimension_numbers(), instr->feature_group_count(), - instr); - } else if (call_target == kCudnnConvBackwardFilterCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, - instr->window(), instr->convolution_dimension_numbers(), - instr->feature_group_count(), instr); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instr->ToString(); - } + StatusOr> alg_scratch_and_tc = + PickBestAlgorithm(Cast(instr)); if (!alg_scratch_and_tc.ok()) { LOG(ERROR) << alg_scratch_and_tc.status(); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index 0cb01161b023b900c8c4b1386b679fe2bd5db802..f79b113f8fac0190adef9a8d68d1617710b1402c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -49,10 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); StatusOr> PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - HloInstruction* instr); + const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 9bf721ecd2ad938e71f88a6fc65cd2d3bd25161e..228379a2488a8564564e8b5e35a863553f4bbac2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include #include #include @@ -59,8 +60,6 @@ std::tuple MatchBackwardFilter( HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); - // TODO(b/31709653): Figure out if we can use grouped convolutions also on - // backward filter. if (conv->feature_group_count() > 1) { return no_match_result; } @@ -218,13 +217,16 @@ std::tuple MatchBackwardFilter( // Try to match a backward input pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple MatchBackwardInput( - HloInstruction* conv) { +std::tuple +MatchBackwardInput(HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); - // TODO(b/31709653): Figure out if we can use grouped convolutions also on - // backward input. + // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also + // for the backward input convolution, but at least for now with version 7.1.4 + // it is slower. This needs to be re-evaluated for future cuDNN versions. + // Note that we already have the necessary code down below, the only thing to + // enable it is to remove the following early return. if (conv->feature_group_count() > 1) { return no_match_result; } @@ -232,51 +234,38 @@ std::tuple MatchBackwardInput( // Match instruction pattern. CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); HloInstruction* reverse_filter = conv->mutable_operand(1); - - // Match the reverse of the filter. ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); - const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions(); - if (reverse_filter->opcode() == HloOpcode::kReverse) { - if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() || - !std::is_permutation(kernel_spatial_dims.begin(), - kernel_spatial_dims.end(), - reverse_filter->dimensions().begin())) { - VLOG(1) - << "Backward input convolution should reverse all kernel dimensions."; - return no_match_result; - } - } else if (reverse_filter->IsConstant()) { - // If the filter is a constant, we're willing to pattern-match to a - // backwards-input conv, on the theory that - // - // a) reversing a constant is free, and - // b) even if the user specified this filter as reverse(constant), we would - // long ago have constant-folded away the reverse. - // - // If the constant has any other uses, reversing it isn't entirely free, - // since we'd now have two constants to keep in memory. But hopefully it's - // free enough. - // - // TODO(jlebar): Should we do this even if the filter is not a constant? - // Reversing a non-constant filter is probably cheaper than padding the - // input! - - // Nothing to do, just fall through. - } else { - // Possibly 1x1 filter. - for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) { - if (conv->window().dimensions(i).size() != 1) { - VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: " - << reverse_filter->ToString(); - return no_match_result; - } - } - if (!window_util::HasBaseDilation(conv->window())) { - VLOG(1) << conv->ToString() - << " is a regular forward convolution. No need " - "to fold it to a backward input convolution."; - return no_match_result; - } + + // We pattern-match to a backwards input conv if: + // + // - all spatial dims of the filter are reversed + // + // OR + // + // - filter is 1x1 or a constant AND + // - conv has base dilation (otherwise this is just a regular forward conv). + // + // The final criterion above is just for canonicalization; cudnn seems to run + // just as fast if we canonicalize 1x1/constant filters without base dilation + // to forward or backward convs. We canonicalize to forward conv because (a) + // it's more natural (constant filters usually show up when doing inference, + // and having backwards convolutions in inference graphs would be weird), and + // (b) cudnn has special fusions for forward conv plus bias and activation, + // and we want to pattern-match to that after running this pass. + bool is_reversed_filter = + reverse_filter->opcode() == HloOpcode::kReverse && + absl::c_is_permutation(dnums.kernel_spatial_dimensions(), + reverse_filter->dimensions()); + bool is_1x1_filter = + absl::c_all_of(conv->window().dimensions(), + [](const WindowDimension& d) { return d.size() == 1; }); + if (!is_reversed_filter && + !(window_util::HasBaseDilation(conv->window()) && + (reverse_filter->IsConstant() || is_1x1_filter))) { + VLOG(1) << "Can't match to backwards convolution. Either filter is not " + "kReverse, or it's not a base-dilated conv with a 1x1 or " + "constant filter."; + return no_match_result; } // Match padding and dilation of the forward convolution. @@ -401,26 +390,64 @@ std::tuple MatchBackwardInput( } } - // OK, it's a match! Canonicalize the conv's filter so that it's a reverse. - // This simplifies things for our caller, and algebraic-simplifier will later - // remove any unnecessary reverses. - if (reverse_filter->opcode() != HloOpcode::kReverse) { + // OK, it's a match! Switch the input feature dimension with the output + // feature dimension. This is the way cuDNN expects it to be. + dnums.set_kernel_input_feature_dimension( + conv->convolution_dimension_numbers().kernel_output_feature_dimension()); + dnums.set_kernel_output_feature_dimension( + conv->convolution_dimension_numbers().kernel_input_feature_dimension()); + + // If we matched against a constant, we need to add a reverse op that can be + // subsumed by the cuDNN call. algebraic-simplifier will later remove any + // unnecessary reverses. + if (reverse_filter->opcode() != HloOpcode::kReverse && + reverse_filter->IsConstant()) { // Create a double-reverse, which is a nop. HloComputation* c = conv->parent(); - reverse_filter = c->AddInstruction( - HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, - AsInt64Slice(kernel_spatial_dims))); - reverse_filter = c->AddInstruction( - HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, - AsInt64Slice(kernel_spatial_dims))); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); } - dnums.set_kernel_input_feature_dimension( - conv->convolution_dimension_numbers().kernel_output_feature_dimension()); - dnums.set_kernel_output_feature_dimension( - conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, new_window, dnums); + // Calculate the 'rhs' that goes into the backward input convolution. + HloInstruction* rhs = reverse_filter; + // One reverse is subsumed by the cuDNN call. + if (rhs->opcode() == HloOpcode::kReverse) { + rhs = rhs->mutable_operand(0); + } + if (conv->feature_group_count() == 1) { + return std::make_tuple(true, new_window, dnums, rhs); + } + + // Handle grouped convolutions. Because we swapped the input feature dimension + // with the output feature dimension, we need to also reshape the kernel so + // that the 'feature_group_count' parameter still makes sense. The + // 'feature_group_count' parameter essentially specifies how often the + // 'kernel_input_feature_dimension' is repeated. So when we swap these + // dimensions, we need to divide the new 'kernel_input_feature_dimension' by + // 'feature_group_count' and multiply the new + // 'kernel_output_feature_dimension' by 'feature_group_count'. + Shape new_shape = rhs->shape(); + int64 input_feature_dimension = dnums.kernel_input_feature_dimension(); + int64 output_feature_dimension = dnums.kernel_output_feature_dimension(); + + // In the backward convolution case, the spatial dimensions become the + // feature dimensions, and we are guaranteed that the spatial dimensions are + // adjacent. + CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL); + int64 input_features = new_shape.dimensions(input_feature_dimension); + int64 output_features = new_shape.dimensions(output_feature_dimension); + new_shape.set_dimensions(input_feature_dimension, + input_features / conv->feature_group_count()); + new_shape.set_dimensions(output_feature_dimension, + output_features * conv->feature_group_count()); + HloComputation* c = conv->parent(); + rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs)); + return std::make_tuple(true, new_window, dnums, rhs); } // Tries to rewrite a single convolution into a call to cudnn. @@ -431,6 +458,7 @@ StatusOr RunOnInstruction(HloInstruction* conv) { bool match; Window window; ConvolutionDimensionNumbers dnums; + HloInstruction* rhs; std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { @@ -439,13 +467,8 @@ StatusOr RunOnInstruction(HloInstruction* conv) { window, dnums, conv->feature_group_count()); } - std::tie(match, window, dnums) = MatchBackwardInput(conv); + std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); if (match) { - // Backward input conv subsumes the conv plus the reverse in operand 1. - HloInstruction* reverse = conv->mutable_operand(1); - CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); - HloInstruction* rhs = reverse->mutable_operand(0); - return CreateCudnnConvBackwardInput(conv->shape(), conv->mutable_operand(0), rhs, window, dnums, conv->feature_group_count()); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index bda8ebe579778107a6569079ab94d5822dff6749..d237f8930b74d460ad3d4602670a5afb19b496a2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -590,7 +590,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) { Array4D constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); string constant_str = - LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString(); + LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); ParseAndVerifyModule(absl::StrFormat(R"( HloModule test diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 05125e9d1fb3cd03cb72b7854fc28c767b49fd64..2a86ac265e4d6a6502162ac33b04b0ee362ce49e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -72,14 +72,22 @@ class ScratchBufAllocator : public se::ScratchAllocator { }; template -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, DeviceMemory input_buf, - DeviceMemory filter_buf, DeviceMemory output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - AlgorithmConfig algorithm, Stream* stream, - ProfileResult* profile_result /*= nullptr*/) { +Status RunCudnnConvolutionImpl(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + CudnnConvKind kind = params.kind; + const Shape& input_shape = *params.input_shape; + const Shape& filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + DeviceMemory input_buf(params.input_buf); + DeviceMemory filter_buf(params.filter_buf); + DeviceMemory output_buf(params.output_buf); + const Window& window = *params.window; + const ConvolutionDimensionNumbers& dnums = *params.dnums; + int64 feature_group_count = params.feature_group_count; + AlgorithmConfig algorithm = params.algorithm; + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); VLOG(3) << "tensor_ops_enabled: " << algorithm.algorithm().tensor_ops_enabled(); @@ -219,54 +227,31 @@ string CudnnConvKindToString(CudnnConvKind kind) { } } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { +Status RunCudnnConvolution(CudnnConvParams params, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, input_buf, filter_buf, - output_buf, &scratch_allocator, window, dnums, feature_group_count, - algorithm, stream, profile_result); + return RunCudnnConvolution(params, &scratch_allocator, stream, + profile_result); } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - PrimitiveType output_primitive_type = output_shape.element_type(); +Status RunCudnnConvolution(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + PrimitiveType output_primitive_type = params.output_shape->element_type(); switch (output_primitive_type) { case F16: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, - dnums, feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl(params, scratch_allocator, + stream, profile_result); case F32: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, dnums, - feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl(params, scratch_allocator, stream, + profile_result); case F64: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, - dnums, feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl(params, scratch_allocator, stream, + profile_result); default: - LOG(FATAL) << ShapeUtil::HumanString(output_shape); + LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); } } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index a1b4fc71d0cac3e5ea067ca7941b07cbade8d7cc..381aa37a1b1405e00d62adf9855e9229482f5b86 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -47,6 +47,20 @@ enum class CudnnConvKind { kBackwardFilter, // input + output => filter }; +struct CudnnConvParams { + CudnnConvKind kind; + const Shape* input_shape; + const Shape* filter_shape; + const Shape* output_shape; + se::DeviceMemoryBase input_buf; + se::DeviceMemoryBase filter_buf; + se::DeviceMemoryBase output_buf; + const Window* window; + const ConvolutionDimensionNumbers* dnums; + int64 feature_group_count; + se::dnn::AlgorithmConfig algorithm; +}; + // Converts a CudnnConvKind value to a string. string CudnnConvKindToString(CudnnConvKind kind); @@ -55,10 +69,9 @@ string CudnnConvKindToString(CudnnConvKind kind); // Note that depending on the value of CudnnConvKind, the result of this call // may be written into input_buf, filter_buf, or output_buf! // -// At the moment we only support cudnn convolutions over float and half, and -// convolution with half data type is implemented with cudnn PSEUDO_HALF -// configuration, that is, the input values are half and the internal -// computation type is float. +// At the moment convolution with half data type is implemented with cudnn +// PSEUDO_HALF configuration, that is, the input values are half and the +// internal computation type is float. // // We provide one overload which takes a scratch buffer, and another which takes // an allocator which is responsible for allocating the scratch space. In @@ -70,23 +83,14 @@ string CudnnConvKindToString(CudnnConvKind kind); // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConvolution(CudnnConvParams params, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConvolution(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index ea9376e101ff9bfda3cdc78b67a83d369674511e..02a0d028c118aba23996f9b97d05443bb4a00c88 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -21,9 +21,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 59ade96f7d262a71add4a15680690ac2fc7b4821..b857fa775a76ec999b505a2a64332cc0c54cf00b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -24,14 +24,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class GpuHloScheduleTest : public HloTestBase { +class GpuHloScheduleTest : public HloVerifiedTestBase { protected: using HloVec = std::vector; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 0a4089df4c954cafcbe241189ee79a0995683513..27a4d0b601f3807fe6b94dd6171a44f292921ede 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class GpuHloSupportCheckerTest : public HloTestBase { +class GpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: GpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module.get()).status(); + Status status = checker().Run(module).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("GPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 20d523abe0552f0bc22c365007c096666ec888f6..22f43bc08bd08abd735f88f32f28c528499cf3d2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -287,5 +288,42 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, value->getType()); } +Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, + CudnnConvParams* params) { + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + custom_call->backend_config()); + const auto& target = custom_call->custom_call_target(); + const auto& lhs_shape = custom_call->operand(0)->shape(); + const auto& rhs_shape = custom_call->operand(1)->shape(); + const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); + + params->window = &custom_call->window(); + params->dnums = &custom_call->convolution_dimension_numbers(); + params->feature_group_count = custom_call->feature_group_count(); + params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( + backend_config.algorithm(), backend_config.tensor_ops_enabled())); + + if (target == kCudnnConvForwardCallTarget) { + params->kind = CudnnConvKind::kForward; + params->input_shape = &lhs_shape; + params->filter_shape = &rhs_shape; + params->output_shape = &conv_result_shape; + } else if (target == kCudnnConvBackwardInputCallTarget) { + params->kind = CudnnConvKind::kBackwardInput; + params->input_shape = &conv_result_shape; + params->filter_shape = &rhs_shape; + params->output_shape = &lhs_shape; + } else if (target == kCudnnConvBackwardFilterCallTarget) { + params->kind = CudnnConvKind::kBackwardFilter; + params->input_shape = &lhs_shape; + params->filter_shape = &conv_result_shape; + params->output_shape = &rhs_shape; + } else { + LOG(FATAL) << "Unexpected custom call target: " + << custom_call->custom_call_target(); + } + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 59c65fc2686cd4a00a3770ebaedf637e8f556828..09c455cc1e137b4a9836a58d5b70e62a4bfa120a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -20,7 +20,9 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" // TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they // don't belong in "ir_emission_utils". @@ -148,6 +150,11 @@ llvm::Value* EmitPrintf(absl::string_view fmt, llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); +// Populates params using conv, which must be a custom-call to a cudnn +// convolution. Does not modify any buffers in the params. +Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, + CudnnConvParams* params); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f91cc00d713deeb0e3adab0dca968b2f87565376..b669881026276eefe2ca6cbea74d79604dd13066 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -61,6 +61,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -464,67 +465,35 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { if (IsCustomCallToDnnConvolution(*custom_call)) { const auto& assn = ir_emitter_context_->buffer_assignment(); - const auto& lhs_shape = custom_call->operand(0)->shape(); - const auto& rhs_shape = custom_call->operand(1)->shape(); - const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); auto tuple_result_slice = GetAllocationSlice(*custom_call); auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - custom_call->backend_config()); const auto& target = custom_call->custom_call_target(); - std::unique_ptr thunk; + BufferAllocation::Slice input_slice, filter_slice, output_slice; + if (target == kCudnnConvForwardCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kForward, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/conv_result_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = lhs_slice; + filter_slice = rhs_slice; + output_slice = conv_result_slice; } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kBackwardInput, - /*input_buffer=*/conv_result_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/lhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/lhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = conv_result_slice; + filter_slice = rhs_slice; + output_slice = lhs_slice; } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kBackwardFilter, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/conv_result_slice, - /*output_buffer=*/rhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, - /*output_shape=*/rhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = lhs_slice; + filter_slice = conv_result_slice; + output_slice = rhs_slice; } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); } - thunk_sequence_->emplace_back(std::move(thunk)); + thunk_sequence_->emplace_back(absl::make_unique( + Cast(custom_call), input_slice, filter_slice, + output_slice, scratch_slice, tuple_result_slice)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index f6325b33680629b7e3d3814b088582a5007de6dc..dfdcf1875dd3f5749bd1fd95ad0eeb8c11955887 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -208,10 +208,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); pipeline.AddPass(); - // CudnnConvolutionRewriter may add instructions of the form - // reverse(constant), which it expects will be simplified by constant - // folding. - pipeline.AddPass(); pipeline.AddPass(); if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass(); @@ -219,6 +215,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // pairs that TupleSimplifier fixes. pipeline.AddPass(); } + // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add + // instructions which can be simplified by constant folding. + pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc index fa84d7722351b68770b876e3880b472eec3233d7..b0061fa6558ac92bffd3dff13e736421a62dc484 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -23,7 +23,6 @@ limitations under the License. namespace xla { namespace gpu { - // We want the input/output feature counts of an f16 conv to be factors of 8, // because without this cudnn can't use tensor cores on the conv. static constexpr int64 kDesiredNumFeaturesFactor = 8; @@ -63,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr, HloComputation* comp = instr->parent(); const Shape& shape = instr->shape(); - auto* zero = comp->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(shape.element_type()).CloneToUnique())); + auto* zero = comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 9d85d746d84908eaa8d720bc3cccc475d81710f3..2a6415d0b6c973cb72c30b7a803b5f603c1d5e4d 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -68,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput( conv_window.dimensions(i).base_dilation() - 1); } PrimitiveType element_type = input->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -125,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, HloComputation* computation = kernel->parent(); PrimitiveType element_type = kernel->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -236,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(input->shape().element_type())))); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(input->shape().element_type()))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 8f0dedfa40e57543147af6961a16e6b66a320298..c4f43cc9a614283acb376b5f98e4976615b590ad 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -21,14 +21,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class StreamAssignmentTest : public HloTestBase { +class StreamAssignmentTest : public HloVerifiedTestBase { protected: std::unique_ptr CreateNewModule() { HloModuleConfig config; diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index 4550f36fdfc097632fed4956fcd3e42ef8a919c5..780539c164277f14c2bd964024f7c3ca179f4ada 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {}; TEST_F(GpuCopyTest, UseMemcpy) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); builder.AddInstruction(HloInstruction::CreateUnary( diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index 9072b30317d253fd6d50e9d98949cad4eaebfe7b..f8120a5fa00ce38644cd85c54d5ef65701be1eda 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } @@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) { TEST_F(InfeedTest, LargeInfeed) { Array4D array(80, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D(array)); + TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D(array)); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests that a large tuple infeed can be handled. TEST_F(InfeedTest, SingleInfeedLargeTuple) { Array4D array(40, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR4FromArray4D(array).get(), - LiteralUtil::CreateR0(5).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR4FromArray4D(array), + LiteralUtil::CreateR0(5)})); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 40183de96ee363996e6b0b883a78e7a8b5d13ab2..9a61f8ac5a62e38e687a93890eb33481a01d51c8 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -26,9 +26,6 @@ limitations under the License. namespace xla { namespace { -using ::testing::Eq; -using ::testing::HasSubstr; - class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 00a25db467dc795b296b96b65f0eb56f6762f42f..957c4a68915934796a315f2443c90e571e942e75 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -29,14 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloTestBase {}; +class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { auto module = CreateNewModule(); @@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - HloSchedule schedule(module.get()); + HloSchedule schedule(module); schedule.set_sequence(cond_computation, {cond_param, cond_iter, cond_data, cond_lt}); schedule.set_sequence(body_computation, {body_param}); @@ -233,7 +233,7 @@ class HeapSimulatorTracker { HeapSimulator::Result result_; }; -class HeapSimulatorTest : public HloTestBase { +class HeapSimulatorTest : public HloVerifiedTestBase { protected: HeapSimulatorTest() {} ~HeapSimulatorTest() override {} diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 93ec2c9438bf11b8119a947c4465926810129b7f..b19ec126382d143b6ded401f2fad56f950d04bbd 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -309,6 +309,13 @@ message HeapSimulatorTrace { bool whole_module_simulation = 2; } +// An abstraction representing a set of HLO module built to run concurrently +// across different devices. +message HloModuleGroupProto { + string name = 1; + repeated HloModuleProto hlo_modules = 2; +} + // Serialization of BufferAssignment. message BufferAssignmentProto { // Alias represents a source LogicalBuffer, and the buffer location that diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 233d2199d139770fd1cab2a2d1485211f0fcd44a..8c6903d76628f87b01de044f1e49de367bf38110 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -562,9 +562,11 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); - return absl::WrapUnique(new HloComputation(proto.name(), parameter_count, - &instructions, root, - /*fusion_instruction=*/nullptr)); + auto computation = absl::WrapUnique( + new HloComputation(proto.name(), parameter_count, &instructions, root, + /*fusion_instruction=*/nullptr)); + computation->unique_id_ = proto.id(); + return std::move(computation); } void HloComputation::FuseInstructionsInto( diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 8a45939c61755876555bc35c49d7d6c781f8b4fe..f837816cea78d78bb3d605dd91e81cac39036268 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -76,10 +76,10 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } - std::unique_ptr result = evaluator->TryEvaluate(instruction); + Literal result; // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. - if (result == nullptr) { + if (!evaluator->TryEvaluate(instruction, &result)) { VLOG(2) << "Constant folding failed for instruction: " << instruction->ToString(); continue; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 07cd1efc1208309770478885532e0284bdb1fbcc..3e0def5d26a0033d954a776c1c32d6c35acfb505 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" @@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using HloConstantFoldingTest = HloTestBase; +using HloConstantFoldingTest = HloVerifiedTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); @@ -52,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -73,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -94,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -134,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->Literal::CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); @@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { root->literal().EachCell( [&](absl::Span indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - matched = matched && (value == literal_clone->Get(rindexes)); + matched = matched && (value == literal_clone.Get(rindexes)); }); EXPECT_TRUE(matched); } @@ -219,28 +219,27 @@ const char* const kConstantFoldReduce = R"( })"; TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(kConstantFoldReduce)); + ParseAndVerifyModule(kConstantFoldReduce); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); EXPECT_TRUE(result); - EXPECT_EQ(6, module->entry_computation() + EXPECT_EQ(6, module() + .entry_computation() ->root_instruction() ->literal() .GetFirstElement()); } TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(kConstantFoldReduce)); - HloInstruction* add = module->computations().begin()->root_instruction(); + ParseAndVerifyModule(kConstantFoldReduce); + HloInstruction* add = module().computations().begin()->root_instruction(); LayoutUtil::ClearLayout(add->mutable_shape()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); EXPECT_FALSE(result); - EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index a3fcc0fefa5d132166849ecb4a5877626bee3530..b76c50bb5b99cf4c9e6d4e04c240e8159acfc338 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -321,18 +321,17 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, padding_config_dim.set_edge_padding_high(zeros_to_append); *padding_config.add_dimensions() = padding_config_dim; - HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(operand->shape().element_type())))); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(operand->shape().element_type()))); return MakePadHlo(operand, zero, padding_config); } StatusOr BroadcastZeros( HloComputation* computation, PrimitiveType element_type, absl::Span broadcast_dimensions) { - HloInstruction* zero = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index eb6affadc800d9d5cf7b143386b46f3e8c608e63..e07a196d1154dc0ea45ccd2f15b0b9b56f7c41f8 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -57,10 +57,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { entry_computation->set_root_instruction(first_1_dims_collapsed); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({3, 4})); + CHECK_EQ(result_literal, LiteralUtil::CreateR1({3, 4})); } TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { @@ -78,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( + Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2( + CHECK_EQ(result_literal, + LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); } @@ -103,10 +103,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9, 10}})); + Literal result_literal, + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9, 10}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { @@ -124,10 +124,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR3({{{9, 10}}})); + Literal result_literal, + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{9, 10}}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { @@ -144,10 +144,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR0(9)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9}})); } TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { @@ -165,11 +165,11 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( + Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { @@ -187,10 +187,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { entry_computation->set_root_instruction(zero_padded_param); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); + CHECK_EQ(result_literal, LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { @@ -208,10 +208,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR0(0)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{0, 0}, {0, 0}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { @@ -229,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR0(0.0f)})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index e09d5868f23e9d5a3b239f6831c31c41966ffa2e..9b18b0284f63c25934c1b7118dc8973caa62cadc 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -73,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR0(84.0); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { @@ -105,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { @@ -135,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, ConstantsSameValueDifferentType) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index d0d955fea8e17b7ae5f0099f5c51cf00572d436a..06b6d5b5592c5849dd247fc19fc52ab0a2113fe8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -54,9 +54,8 @@ namespace xla { namespace { template -StatusOr> Compare(const Shape& shape, HloOpcode opcode, - LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -94,9 +93,9 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); })); @@ -105,9 +104,9 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, } template <> -StatusOr> Compare( - const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -125,9 +124,9 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); })); @@ -193,7 +192,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) } template -StatusOr> HloEvaluator::Evaluate( +StatusOr HloEvaluator::Evaluate( const HloModule& module, absl::Span arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); @@ -206,11 +205,21 @@ StatusOr> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) - .CloneToUnique(); + .Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + const HloModule& module, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate(module, arg_literal_ptrs); } template -StatusOr> HloEvaluator::Evaluate( +StatusOr HloEvaluator::Evaluate( const HloComputation& computation, absl::Span arg_literals) { CHECK(computation.parent() != nullptr); @@ -224,11 +233,21 @@ StatusOr> HloEvaluator::Evaluate( } TF_RETURN_IF_ERROR(computation.Accept(this)); - return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique(); + return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + const HloComputation& computation, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate(computation, arg_literal_ptrs); } template -StatusOr> HloEvaluator::Evaluate( +StatusOr HloEvaluator::Evaluate( HloInstruction* instruction, absl::Span arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); @@ -247,18 +266,27 @@ StatusOr> HloEvaluator::Evaluate( << input_literal->ToString(); TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - evaluated_[operand] = input_literal->CloneToUnique(); + evaluated_[operand] = input_literal->Clone(); } } TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + HloInstruction* instruction, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal : arg_literals) { + arg_literal_ptrs.push_back(&literal); + } + return Evaluate(instruction, arg_literal_ptrs); } -StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction) { +StatusOr HloEvaluator::Evaluate(HloInstruction* instruction) { if (instruction->opcode() == HloOpcode::kParameter) { return tensorflow::errors::FailedPrecondition( "Cannot evaluate a parameter."); @@ -274,21 +302,22 @@ StatusOr> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); } -std::unique_ptr HloEvaluator::TryEvaluate( - HloInstruction* instruction) { +bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) { + CHECK(result != nullptr); auto result_or = Evaluate(instruction); if (!result_or.ok()) { VLOG(1) << "TryEvaluate failed:" << result_or.status(); - return nullptr; + return false; } - return result_or.ConsumeValueOrDie(); + *result = result_or.ConsumeValueOrDie(); + return true; } -StatusOr> HloEvaluator::EvaluateWithSubstitutions( +StatusOr HloEvaluator::EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map& substitutions) { @@ -299,7 +328,7 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( owned_operands.push_back(operand->Clone()); } else { owned_operands.push_back( - HloInstruction::CreateConstant(it->second->CloneToUnique())); + HloInstruction::CreateConstant(it->second->Clone())); } } @@ -316,12 +345,12 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( return result; } -StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( +StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); std::unique_ptr cloned_instruction = HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(), @@ -331,10 +360,10 @@ StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( return result; } -StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( +StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( HloOpcode opcode, const Literal& operand) { std::unique_ptr operand_instr = - HloInstruction::CreateConstant(operand.CloneToUnique()); + HloInstruction::CreateConstant(operand.Clone()); std::unique_ptr cloned_instruction = HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); @@ -343,14 +372,14 @@ StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( return result; } -StatusOr> HloEvaluator::EvaluateDotOp( +StatusOr HloEvaluator::EvaluateDotOp( const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); TF_ASSIGN_OR_RETURN( Shape dot_shape, @@ -371,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { << ", but input literal shape is: " << ShapeUtil::HumanString(input_literal->shape()); - evaluated_[parameter] = input_literal->CloneToUnique(); + evaluated_[parameter] = input_literal->Clone(); return Status::OK(); } @@ -421,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (auto operand : operands) { const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( GetEvaluatedLiteralFor(operand), source_indices, dest_indices, AsInt64Slice(operand_shape.dimensions()))); dest_indices[concat_dim] += @@ -824,7 +853,7 @@ class OutputOffsetIndexToInputIndex { // there is one) to `reshaped_start_indices`. static StatusOr> ReshapedGatherIndices( int64 index_vector_dim, const Literal& start_indices, - std::unique_ptr* reshaped_start_indices) { + Literal* reshaped_start_indices) { if (start_indices.shape().dimensions_size() != index_vector_dim) { return std::cref(start_indices); } @@ -834,16 +863,16 @@ static StatusOr> ReshapedGatherIndices( new_shape.push_back(1); TF_ASSIGN_OR_RETURN(*reshaped_start_indices, start_indices.Reshape(new_shape)); - return std::cref(**reshaped_start_indices); + return std::cref(*reshaped_start_indices); } Status HloEvaluator::HandleGather(HloInstruction* gather) { - std::unique_ptr result = Literal::CreateFromShape(gather->shape()); + Literal result = Literal::CreateFromShape(gather->shape()); const Shape& shape = gather->shape(); const GatherDimensionNumbers& dim_numbers = gather->gather_dimension_numbers(); const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); - std::unique_ptr reshaped_start_indices; + Literal reshaped_start_indices; TF_ASSIGN_OR_RETURN( const Literal& start_indices, ReshapedGatherIndices(dim_numbers.index_vector_dim(), @@ -908,7 +937,7 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } TF_RETURN_IF_ERROR( - result->CopyElementFrom(operand, input_index, output_index)); + result.CopyElementFrom(operand, input_index, output_index)); return true; }; @@ -940,8 +969,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { // Checks that operand's dimensions are the same as the broadcast's // dimensions along the dimensions to be broadcasted. for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand.shape().dimensions(i)); + auto operand_dim_size = operand.shape().dimensions(i); + auto broadcast_dim_size = + broadcast->shape().dimensions(broadcast->dimensions(i)); + TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat( + "Operand dimension %d is broadcast to output dimension %d, but the " + "sizes of these two dims do not match (%d vs %d): %s", + i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size, + broadcast->ToString()); } TF_ASSIGN_OR_RETURN( @@ -971,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = absl::make_unique( - ShapeUtil::GetTupleElementShape(operand->shape(), index)); - return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, - /*dest_shape_index=*/{}, - /*src_shape_index=*/{index}); + evaluated_[get_tuple_element] = + Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index)); + return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal, + /*dest_shape_index=*/{}, + /*src_shape_index=*/{index}); } Status HloEvaluator::HandleCopy(HloInstruction* copy) { TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape())); - - auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique(); - evaluated_[copy] = std::move(result); + evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone(); return Status::OK(); } @@ -998,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } HloEvaluator embedded_evaluator; - std::unique_ptr result = + Literal result = embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); @@ -1030,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { } HloEvaluator embedded_evaluator; - std::unique_ptr result = + Literal result = embedded_evaluator .Evaluate(*readded_computation, arg_literals) .ConsumeValueOrDie(); @@ -1050,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* false_computation = conditional->false_computation(); HloEvaluator embedded_evaluator; - std::unique_ptr result; + Literal result; if (pred.Get({})) { result = embedded_evaluator .Evaluate(*true_computation, @@ -1075,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { // If predicate is of scalar type, no element-wise selection would be needed. if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get({})) { - evaluated_[select] = on_true.CloneToUnique(); + evaluated_[select] = on_true.Clone(); } else { - evaluated_[select] = on_false.CloneToUnique(); + evaluated_[select] = on_false.Clone(); } return Status::OK(); } @@ -1091,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) { const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2)); if (pred.Get({})) { - evaluated_[tuple_select] = on_true.CloneToUnique(); + evaluated_[tuple_select] = on_true.Clone(); } else { - evaluated_[tuple_select] = on_false.CloneToUnique(); + evaluated_[tuple_select] = on_false.Clone(); } return Status::OK(); } @@ -1102,7 +1135,7 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloComputation* cond_comp = while_hlo->while_condition(); HloComputation* body_comp = while_hlo->while_body(); // Initialize the loop carried valued with the input to the While instruction. - auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique(); + auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone(); bool keep_going = true; int64 iteration_count = 0; HloEvaluator cond_evaluator(max_loop_iterations_); @@ -1112,13 +1145,13 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", while_hlo->name(), max_loop_iterations_); } - TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate( - *cond_comp, {lcv.get()})); - keep_going = cond_val->GetFirstElement(); + TF_ASSIGN_OR_RETURN(auto cond_val, + cond_evaluator.Evaluate(*cond_comp, {&lcv})); + keep_going = cond_val.GetFirstElement(); if (keep_going) { TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( - *body_comp, {lcv.get()})); - VLOG(3) << "Loop iteration result: " << body_val->ToString(); + *body_comp, {&lcv})); + VLOG(3) << "Loop iteration result: " << body_val.ToString(); lcv = std::move(body_val); cond_evaluator.ResetVisitStates(); loop_body_evaluator.ResetVisitStates(); @@ -1133,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { // hoops to make this work. namespace { template -StatusOr> EvaluateSortInternal( - HloInstruction* sort, const Literal& keys_literal, - const Literal& values_literal) { +StatusOr EvaluateSortInternal(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { auto rank = ShapeUtil::Rank(keys_literal.shape()); TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) @@ -1173,57 +1206,55 @@ StatusOr> EvaluateSortInternal( result_keys.push_back(key_value.first); result_values.push_back(key_value.second); } - auto result_keys_literal = absl::make_unique(keys_literal.shape()); - result_keys_literal->PopulateR1(absl::Span(result_keys)); - auto result_values_literal = - absl::make_unique(values_literal.shape()); - result_values_literal->PopulateR1( + Literal result_keys_literal(keys_literal.shape()); + result_keys_literal.PopulateR1(absl::Span(result_keys)); + Literal result_values_literal(values_literal.shape()); + result_values_literal.PopulateR1( absl::Span(result_values)); return std::make_pair(std::move(result_keys_literal), std::move(result_values_literal)); }; - std::unique_ptr result_tuple; + Literal result_tuple; if (rank == 1) { auto result_pair = sort_r1(keys_literal, values_literal); - result_tuple = LiteralUtil::MakeTuple( - {result_pair.first.get(), result_pair.second.get()}); + result_tuple = + LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second}); } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. - auto keys_result_literal = absl::make_unique(keys_literal.shape()); - auto values_result_literal = - absl::make_unique(values_literal.shape()); + Literal keys_result_literal(keys_literal.shape()); + Literal values_result_literal(values_literal.shape()); int64 r1_length = keys_literal.shape().dimensions(1); for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto keys_r1_slice, keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); + .Reshape({r1_length})); TF_ASSIGN_OR_RETURN(auto values_r1_slice, values_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice); + .Reshape({r1_length})); + auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice); TF_ASSIGN_OR_RETURN(auto sorted_keys, - r1_result_pair.first->Reshape({1, r1_length})); + r1_result_pair.first.Reshape({1, r1_length})); TF_ASSIGN_OR_RETURN(auto sorted_values, - r1_result_pair.second->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom( - *sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); - TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom( - *sorted_values, {0, 0}, {row, 0}, {1, r1_length})); + r1_result_pair.second.Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( + sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); + TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( + sorted_values, {0, 0}, {row, 0}, {1, r1_length})); } - result_tuple = LiteralUtil::MakeTuple( - {keys_result_literal.get(), values_result_literal.get()}); + result_tuple = + LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); } - VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); + VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); return std::move(result_tuple); } template -StatusOr> EvaluateSortCurried( - HloInstruction* sort, const Literal& keys_literal, - const Literal& values_literal) { +StatusOr EvaluateSortCurried(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { switch (sort->operand(1)->shape().element_type()) { case F32: return EvaluateSortInternal(sort, keys_literal, @@ -1242,9 +1273,9 @@ StatusOr> EvaluateSortCurried( } } -StatusOr> EvaluateSort(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { +StatusOr EvaluateSort(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { switch (sort->operand(0)->shape().element_type()) { case F32: return EvaluateSortCurried(sort, keys_literal, values_literal); @@ -1308,33 +1339,25 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) { Status HloEvaluator::Postprocess(HloInstruction* hlo) { VLOG(2) << "Finished visiting " << hlo->ToString() << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); + // Out of convenience the literal may have been produced with a different + // layout. Relayout as indicated by the HLO instruction. + if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(), + hlo->shape())) { + evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape()); + } return Status::OK(); } // Explicit instantiation of templatized Evaluate* methods. // -template StatusOr> -HloEvaluator::Evaluate( +template StatusOr HloEvaluator::Evaluate( const HloModule& module, absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( - const HloModule& module, - absl::Span> arg_literals); - -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(const HloComputation& computation, - absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( + +template StatusOr HloEvaluator::Evaluate( const HloComputation& computation, - absl::Span> arg_literals); + absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate( +template StatusOr HloEvaluator::Evaluate( HloInstruction* instruction, absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( - HloInstruction* instruction, - absl::Span> arg_literals); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 72252bafc767ee03e2c91b6beedf49c7e0902531..21e676d671af08d1626ca6f157db63bf8d23ae0b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -47,11 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - const HloModule& module, absl::Span arg_literals); + StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -69,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - const HloComputation& computation, - absl::Span arg_literals); + StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals); // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. @@ -82,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // 1. argument literals correspond to the input instruction's parameters in // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - HloInstruction* instruction, absl::Span arg_literals); + StatusOr Evaluate(HloInstruction* instruction, + absl::Span arg_literals); // Evaluates a single HLO instruction with constant operands. // Returns the evaluated result as literal if successful. // Precondition: // 1. all operands of the input instruction are constants. // 2. the instruction is not a Parameter operation. - StatusOr> Evaluate(HloInstruction* instruction); + StatusOr Evaluate(HloInstruction* instruction); - // Same as Evaluate, except returning nullptr on error. - std::unique_ptr TryEvaluate(HloInstruction* instruction); + // Same as Evaluate, except returning false on error and accepts an output + // pointer. + bool TryEvaluate(HloInstruction* instruction, Literal* result); // Evaluates a single HLO instruction, substituting the given literals for // some of the instruction's operands. // // For example, given instruction = op(A, B, C) and the map // {A = x, C = y}, this evaluates op(x, B, y). - StatusOr> EvaluateWithSubstitutions( + StatusOr EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map& substitutions); - StatusOr> EvaluateElementwiseBinaryOp( - HloOpcode opcode, const Literal& lhs, const Literal& rhs); + StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs); - StatusOr> EvaluateElementwiseUnaryOp( - HloOpcode opcode, const Literal& operand); + StatusOr EvaluateElementwiseUnaryOp(HloOpcode opcode, + const Literal& operand); - StatusOr> EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config, const Literal& lhs, - const Literal& rhs); + StatusOr EvaluateDotOp(const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + const Literal& lhs, const Literal& rhs); protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this @@ -197,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); - return *(it->second); + return it->second; } // Tracks the HLO instruction and its evaluated literal result. @@ -205,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // that are no longer a parent for any other subsequent instruction in // post-orderring. // Must be cleared for each evaluation. - tensorflow::gtl::FlatMap> - evaluated_; + // Storing Literal in place require the container to have pointer stability so + // we cannot use FlatMap any more. + std::unordered_map evaluated_; private: template - static StatusOr> ElementWiseUnaryOpImpl( + static StatusOr ElementWiseUnaryOpImpl( HloInstruction* instruction, const std::function& unary_op, const Literal& operand_literal) { @@ -227,9 +228,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { ShapeUtil::HumanString(operand->shape())); } - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 102ebb24abf13ab78db4fa3315ded98b4fefb4b6..01e88566a55dbfddaaec5db6100327a8c1db398b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -56,8 +56,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, evaluator_ = absl::make_unique(); } - std::unique_ptr Evaluate( - absl::Span arg_literals = {}) { + Literal Evaluate(absl::Span arg_literals = {}) { if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. auto type_converter = HloElementTypeConverter(F32, BF16); @@ -69,39 +68,37 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, std::unique_ptr evaluator_; - void TestUnaryOp(HloOpcode opcode, std::unique_ptr expected, - std::unique_ptr input, float aabs = 0) { + void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, + float aabs = 0) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); - b.AddInstruction( - HloInstruction::CreateUnary(expected->shape(), opcode, c1)); + b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - auto element_type = expected->shape().element_type(); + auto element_type = expected.shape().element_type(); if (element_type == F32 || element_type == F64) { ErrorSpec error(aabs); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error)); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } } - void TestBinaryOp(HloOpcode opcode, std::unique_ptr expected, - std::unique_ptr lhs, - std::unique_ptr rhs) { + void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs, + Literal rhs) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); b.AddInstruction( - HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2)); + HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } bool use_bfloat16_; @@ -117,7 +114,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - Shape shape = low->shape(); + Shape shape = low.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -126,11 +123,11 @@ TEST_P(HloEvaluatorTest, DoesClamp) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { @@ -138,7 +135,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); auto high = LiteralUtil::CreateR0(1.f); - Shape shape = value->shape(); + Shape shape = value.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -147,11 +144,11 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{0, 0}, {1, 1}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select @@ -161,7 +158,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); - Shape shape = on_true->shape(); + Shape shape = on_true.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred))); auto c2 = @@ -172,11 +169,11 @@ TEST_P(HloEvaluatorTest, DoesSelect) { HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -295,7 +292,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); - std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; + std::vector args = {&lhs, &rhs, &rhs2}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -313,11 +310,11 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { lhs_instruction, param_rhs2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(args); + Literal result = Evaluate(args); auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies Reshape operation is correctly evaluated. @@ -327,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(literal))); @@ -337,14 +334,13 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - result->EachCell( - [&](absl::Span indices, NativeT value) { - std::vector rindexes = Permute(permutation, indices); - EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); - }); + result.EachCell([&](absl::Span indices, NativeT value) { + std::vector rindexes = Permute(permutation, indices); + EXPECT_NEAR(value, literal_clone.Get(rindexes), 0.031250); + }); } // Verifies Broadcast operation is correctly evaluated. @@ -356,12 +352,12 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { HloInstruction* literal_instruction = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, {1, 2})); + output_literal.shape(), literal_instruction, {1, 2})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { @@ -374,13 +370,13 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloInstruction::CreateConstant(std::move(input_literal))); // Broadcast dimension should be empty in the case of scalars. b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, + output_literal.shape(), literal_instruction, /*broadcast_dimensions=*/{})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { @@ -398,11 +394,11 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2( {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { @@ -420,10 +416,10 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({100, 200}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { @@ -432,17 +428,17 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto expected = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); - ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { @@ -452,17 +448,17 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); auto expected = LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); - ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } PaddingConfig CreatePaddingConfig( @@ -495,12 +491,12 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { shape, operand_instruction, padding_value_instruction, padding_config)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { @@ -522,7 +518,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = absl::make_unique>(8, 5, 1, 1); expected_array->Fill(kPadValue); @@ -535,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { auto expected = LiteralUtil::CreateR4FromArray4D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, NegativePadding2D) { @@ -566,7 +562,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } auto expected_array = absl::make_unique>(1, 5); @@ -577,7 +573,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -611,12 +607,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = absl::make_unique>(0, 9); auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank1) { @@ -650,7 +646,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected_array = Array2D({ @@ -662,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // clang-format on auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank1AndRank2) { @@ -696,11 +692,11 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({22.f, 28.f}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank2) { @@ -740,7 +736,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = Array2D({ {22.f, 28.f}, @@ -750,7 +746,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { }); auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SimpleConv1D) { @@ -794,12 +790,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { @@ -849,7 +845,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 4, 4); // clang-format off @@ -862,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { @@ -933,7 +929,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -943,7 +939,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { @@ -1011,7 +1007,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -1021,7 +1017,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { @@ -1071,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 7, 7); expected_array.FillWithYX(Array2D({ @@ -1085,7 +1081,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { @@ -1135,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 8, 8); expected_array.FillWithYX(Array2D({ @@ -1150,7 +1146,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, @@ -1207,7 +1203,7 @@ TEST_P(HloEvaluatorTest, window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 9, 3); expected_array.FillWithYX(Array2D({ @@ -1223,7 +1219,7 @@ TEST_P(HloEvaluatorTest, })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { @@ -1261,14 +1257,14 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); std::iota(input_elems.begin(), input_elems.end(), -7); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4))); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); std::iota(filter_elems.begin(), filter_elems.end(), -31); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4))); @@ -1278,13 +1274,13 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 1, 8); expected_array.FillWithYX( Array2D({{668, 664, 660, 656, 668, 680, 692, 704}})); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; @@ -1317,9 +1313,8 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { module().AddEntryComputation(b.Build()); HloEvaluator hlo_eval; - std::unique_ptr result = - hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); - LiteralTestUtil::ExpectR0Equal(kNumElements, *result); + Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + LiteralTestUtil::ExpectR0Equal(kNumElements, result); } // Reducing many numbers should be fast because it doesn't create @@ -1396,11 +1391,11 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({6, 18}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowMax) { @@ -1448,10 +1443,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{6, 7}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd) { @@ -1505,10 +1500,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{1, 3, 5}, {5, 11, 13}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { @@ -1516,7 +1511,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. std::vector input_dims(6, 4); - std::unique_ptr arg_literal = + Literal arg_literal = LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); HloInstruction* arg_instruction = @@ -1566,12 +1561,12 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); std::vector output_dims = {4, 3, 3, 3, 4, 4}; - std::unique_ptr result_literal = + Literal result_literal = LiteralUtil::CreateFullWithDescendingLayout(output_dims, 8.0f); - EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); } TEST_P(HloEvaluatorTest, StridedSlice) { @@ -1598,14 +1593,14 @@ TEST_P(HloEvaluatorTest, StridedSlice) { /*strides=*/{2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {3}, {19}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DynamicSlice) { @@ -1632,14 +1627,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that the HloEvaluator's implementation goes along with existing @@ -1668,14 +1663,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { @@ -1705,14 +1700,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { shape, operand, update, start_indices)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {1, -2, -3}, {5, -6, -7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetTuples) { @@ -1741,14 +1736,14 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {1, 2, 3}, {5, 6, 7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { @@ -1780,16 +1775,14 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto result_inner_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); - auto expected = LiteralUtil::MakeTuple({ - result_inner_literal.get(), - result_inner_literal.get(), - }); + auto expected = + LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Reverse) { @@ -1820,7 +1813,7 @@ TEST_P(HloEvaluatorTest, Reverse) { b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected = LiteralUtil::CreateR4FromArray4D({ @@ -1842,7 +1835,7 @@ TEST_P(HloEvaluatorTest, Reverse) { }); // clang-format on - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { @@ -1858,12 +1851,13 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}. HloEvaluator evaluator; + Literal param0_literal = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); auto result = evaluator.EvaluateWithSubstitutions( - add, {{param0, LiteralUtil::CreateR1({1, 2, 3, 4}).get()}, - {square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); + add, {{param0, ¶m0_literal}, {square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1883,11 +1877,12 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { // Evaluate add with square = {10, 20, 30, 40}. HloEvaluator evaluator; - auto result = evaluator.EvaluateWithSubstitutions( - add, {{square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); + Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); + auto result = + evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1906,12 +1901,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1930,12 +1925,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1954,14 +1949,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3( + LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), start_indices.get()}))); + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1980,15 +1974,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, @@ -2008,15 +2001,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -2035,12 +2027,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{5}}), - *Evaluate({operand.get(), start_indices.get()}))); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{5}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -2059,13 +2050,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{8}}, {{5}}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{8}}, {{5}}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2084,11 +2074,10 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{}, {}}), - *Evaluate({operand.get(), start_indices.get()}))); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { @@ -2108,12 +2097,12 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr start_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal start_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{0, 1}, {2, 1}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{0, 1}, {2, 1}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { @@ -2138,15 +2127,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { @@ -2171,15 +2158,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { @@ -2205,15 +2191,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { @@ -2239,15 +2223,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { @@ -2273,17 +2255,15 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2( + Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({2, 1}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); + Literal updates = LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2( + LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}), - ErrorSpec{0.1, 0.01})); + Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { @@ -2309,15 +2289,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { @@ -2343,15 +2321,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal updates = LiteralUtil::CreateR3( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { @@ -2376,21 +2353,18 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + Literal expected = LiteralUtil::CreateR3({{{-10, 10}, {-2, 2}, {-3, 3}}, // {{-40, 40}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, @@ -2416,21 +2390,18 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + Literal expected = LiteralUtil::CreateR3({{{-20, 20}, {-10, 10}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { @@ -2455,16 +2426,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10}}); + Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { @@ -2489,17 +2458,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - std::unique_ptr updates = - LiteralUtil::CreateR3({{{10}}, {{20}}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); + Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { @@ -2524,13 +2490,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{}, {}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *operand, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + operand, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { @@ -2557,16 +2521,13 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr scatter_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal scatter_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20}, {30, 40}}); - std::unique_ptr expected = - LiteralUtil::CreateR1({10, 61, 32}); + Literal updates = LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + Literal expected = LiteralUtil::CreateR1({10, 61, 32}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -2603,11 +2564,29 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr arg = LiteralUtil::CreateR1( + Literal arg = LiteralUtil::CreateR1( {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); - std::unique_ptr expected = - LiteralUtil::CreateR0(bfloat16(44.0f)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()}))); + Literal expected = LiteralUtil::CreateR0(bfloat16(44.0f)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); +} + +TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { + // Regression test for b/114735354. + const string hlo_text = R"( +HloModule SliceWithDifferentLayout + +ENTRY main { + arg = f32[2,2,2]{0,1,2} parameter(0) + ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} +} +)"; + ParseAndVerifyModule(hlo_text); + + Literal arg = LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + LayoutUtil::MakeLayout({0, 1, 2})); + Literal actual = Evaluate({&arg}); + EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 63303aef1e37cd9da5ee1139ce72a48e398a61c7..8fb17a00330deae8c004a8d491b46bf7adb84241 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -246,32 +246,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).Convert( convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } Status HandleBitcastConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( convert->shape().element_type())); - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -978,10 +967,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = absl::make_unique(result_shape); + Literal result(result_shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span out_index) { + result.Populate([&](absl::Span out_index) { std::vector from_index(out_index.begin(), out_index.end()); for (const int64 dim : reverse_dimensions) { from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -1157,8 +1146,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast(result_val); }; - auto result = absl::make_unique(result_shape); - TF_RETURN_IF_ERROR(result->PopulateParallel(func)); + Literal result(result_shape); + TF_RETURN_IF_ERROR(result.PopulateParallel(func)); parent_->evaluated_[conv] = std::move(result); return Status::OK(); @@ -1231,9 +1220,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } } - auto result = absl::make_unique(dot->shape()); + Literal result(dot->shape()); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span result_index) { + result.Populate([&](absl::Span result_index) { ElementwiseT result_val = static_cast(0); for (int64 i = 0; i < result_index.size(); i++) { @@ -1280,8 +1269,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = absl::make_unique(pad->shape()); - TF_RETURN_IF_ERROR(result->Populate( + Literal result(pad->shape()); + TF_RETURN_IF_ERROR(result.Populate( [&scalar](absl::Span multi_index) { return scalar; })); const Literal& evaluated_operand = @@ -1289,7 +1278,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), 0); - std::vector target_index(ShapeUtil::Rank(result->shape()), 0); + std::vector target_index(ShapeUtil::Rank(result.shape()), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. @@ -1311,8 +1300,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return true; } } - result->Set(target_index, - evaluated_operand.Get(input_index)); + result.Set(target_index, + evaluated_operand.Get(input_index)); return true; }; @@ -1439,16 +1428,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> MapImpl(HloInstruction* map) { + StatusOr MapImpl(HloInstruction* map) { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = absl::make_unique(map->shape()); + Literal result(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { - std::vector> arg_literals; + result.Populate([&](absl::Span multi_index) { + std::vector arg_literals; arg_literals.reserve(operands.size()); // Construct scalar literal parameters to be passed to the map @@ -1463,16 +1452,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { arg_literals.push_back(std::move(curr_val_literal)); } - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate>(*computation, - arg_literals) + Literal computed_result = + embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on // the same computation. embedded_evaluator.ResetVisitStates(); - return computed_result->Get({}); + return computed_result.Get({}); })); return std::move(result); } @@ -1557,9 +1544,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](const ReturnT& a, const ReturnT& b) { return SafeLess(a, b); }); - auto result_literal = absl::make_unique(keys_literal.shape()); - result_literal->PopulateR1(absl::Span(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + Literal result_literal(keys_literal.shape()); + result_literal.PopulateR1(absl::Span(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal.ToString(); return result_literal; }; @@ -1568,16 +1555,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. - auto result_literal = absl::make_unique(keys_literal.shape()); + Literal result_literal(keys_literal.shape()); int64 r1_length = keys->shape().dimensions(1); for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto r1_slice, keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result = sort_r1(*r1_slice); - TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( - *r1_result, {0, 0}, {row, 0}, {1, r1_length})); + .Reshape({r1_length})); + auto r1_result = sort_r1(r1_slice); + TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( + r1_result, {0, 0}, {row, 0}, {1, r1_length})); } parent_->evaluated_[sort] = std::move(result_literal); } @@ -1651,9 +1638,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - absl::InlinedVector, 1> results(num_args); + absl::InlinedVector results(num_args); for (int64 i = 0; i < num_args; ++i) { - results[i] = absl::make_unique(result_shape); + results[i] = Literal(result_shape); } Status eval_status; @@ -1667,7 +1654,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } for (int64 input = 0; input < num_args; ++input) { - TF_RETURN_IF_ERROR(results[input]->Populate( + TF_RETURN_IF_ERROR(results[input].Populate( [&](absl::Span multi_index) { if (!eval_status.ok()) { return init_scalars[input]; @@ -1703,8 +1690,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } // Evaluate computation with specified literal operands. - absl::InlinedVector, 1> - embedded_operands; + absl::InlinedVector embedded_operands; for (ReturnT value : result_values) { embedded_operands.push_back( LiteralUtil::CreateR0(value)); @@ -1717,11 +1703,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { embedded_operands.size()); std::transform(embedded_operands.begin(), embedded_operands.end(), embedded_operands_ptrs.begin(), - [](const std::unique_ptr& ptr) { - return ptr.get(); - }); + [](Literal& literal) { return &literal; }); - TF_ASSIGN_OR_RETURN(std::unique_ptr computed_result, + TF_ASSIGN_OR_RETURN(Literal computed_result, embedded_evaluator.Evaluate( *function, embedded_operands_ptrs)); // Clear visit states so that we can use the evaluator again on @@ -1729,10 +1713,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { embedded_evaluator.ResetVisitStates(); // Assign computed result to result_val. if (!has_tuple_output) { - result_values[0] = computed_result->Get({}); + result_values[0] = computed_result.Get({}); } else { for (int64 i = 0; i < num_args; ++i) { - result_values[i] = computed_result->Get( + result_values[i] = computed_result.Get( /*multi_index=*/{}, /*shape_index=*/{i}); } } @@ -1748,9 +1732,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (!has_tuple_output) { parent_->evaluated_[reduce] = std::move(results[0]); } else { - auto tuple_result = absl::make_unique(reduce->shape()); + Literal tuple_result(reduce->shape()); for (int64 i = 0; i < num_args; ++i) { - TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i})); + TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i})); } parent_->evaluated_[reduce] = std::move(tuple_result); } @@ -1781,10 +1765,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = absl::make_unique(select_and_scatter->shape()); + Literal result(select_and_scatter->shape()); // Initialize result array with the init value. - TF_RETURN_IF_ERROR(result->Populate( + TF_RETURN_IF_ERROR(result.Populate( [&](absl::Span output_index) { return init_scalar; })); std::vector window_dimension_sizes; @@ -1834,15 +1818,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val = curr_val; selected_index = operand_index; } - curr_val_literal->Set({}, curr_val); - selected_val_literal->Set({}, *selected_val); - std::unique_ptr computed_result = + curr_val_literal.Set({}, curr_val); + selected_val_literal.Set({}, *selected_val); + Literal computed_result = embedded_evaluator .Evaluate( - *select, - {selected_val_literal.get(), curr_val_literal.get()}) + *select, {&selected_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); - bool selected = !computed_result->Get({}); + bool selected = !computed_result.Get({}); if (selected) { selected_val = curr_val; selected_index = operand_index; @@ -1856,16 +1839,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (std::equal(operand_index.begin(), operand_index.end(), selected_index->begin())) { auto source = source_literal.Get(source_index); - auto scattered = result->Get(operand_index); - source_literal_scatter->Set({}, source); - scattered_literal->Set({}, scattered); - std::unique_ptr computed_result = + auto scattered = result.Get(operand_index); + source_literal_scatter.Set({}, source); + scattered_literal.Set({}, scattered); + Literal computed_result = embedded_evaluator - .Evaluate(*scatter, - {source_literal_scatter.get(), - scattered_literal.get()}) + .Evaluate( + *scatter, + {&source_literal_scatter, &scattered_literal}) .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get({})); + result.Set(operand_index, computed_result.Get({})); // Clear visit states so that the we can use the evaluator again // on the same computation. embedded_evaluator.ResetVisitStates(); @@ -1916,10 +1899,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = absl::make_unique(reduce_window->shape()); + Literal result(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span output_index) { + result.Populate([&](absl::Span output_index) { ReturnT result_val = init_scalar; std::fill(window_index.begin(), window_index.end(), 0); @@ -1935,18 +1918,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(curr_val); const auto result_val_literal = LiteralUtil::CreateR0(result_val); - std::unique_ptr computed_result = + Literal computed_result = embedded_evaluator .Evaluate( - *function, - {result_val_literal.get(), curr_val_literal.get()}) + *function, {&result_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again // on the same computation. embedded_evaluator.ResetVisitStates(); - result_val = computed_result->Get({}); + result_val = computed_result.Get({}); }); return result_val; @@ -1961,7 +1943,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // literal (if there is one) to `reshaped_indices`. StatusOr> ReshapedScatterIndices( int64 index_vector_dim, const Literal& indices, - std::unique_ptr* reshaped_indices) { + Literal* reshaped_indices) { if (indices.shape().dimensions_size() != index_vector_dim) { return std::cref(indices); } @@ -1970,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { indices.shape().dimensions().end()); new_shape.push_back(1); TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); - return std::cref(**reshaped_indices); + return std::cref(*reshaped_indices); } // Returns an ShapeUtil::IndexIterationSpace that iterates over the update @@ -2230,7 +2212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { scatter->scatter_dimension_numbers(); const Literal& operand = parent_->GetEvaluatedLiteralFor(scatter->operand(0)); - std::unique_ptr reshaped_scatter_indices; + Literal reshaped_scatter_indices; TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, ReshapedScatterIndices(dim_numbers.index_vector_dim(), parent_->GetEvaluatedLiteralFor( @@ -2260,7 +2242,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Initialize the result with the operand. This makes it easier to handle // the updates even when the indices are repeated. - std::unique_ptr result = operand.CloneToUnique(); + Literal result = operand.Clone(); HloEvaluator embedded_evaluator; auto scatter_inner_loop_body = [&](absl::Span update_window_index, @@ -2299,19 +2281,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } auto result_value_literal = - LiteralUtil::CreateR0(result->Get(input_index)); + LiteralUtil::CreateR0(result.Get(input_index)); auto update_value_literal = LiteralUtil::CreateR0(updates.Get(update_index)); - std::unique_ptr updated_result = + Literal updated_result = embedded_evaluator .Evaluate( *scatter->to_apply(), - {result_value_literal.get(), update_value_literal.get()}) + {&result_value_literal, &update_value_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on the // same computation. embedded_evaluator.ResetVisitStates(); - result->Set(input_index, updated_result->Get({})); + result.Set(input_index, updated_result.Get({})); return true; }; @@ -2359,9 +2341,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get(operand_index); }; - auto result = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate(func)); + Literal result(shape); + TF_RETURN_IF_ERROR(result.Populate(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); } @@ -2575,7 +2556,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (ShapeUtil::Rank(iota->shape()) > 1) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[iota], - result->Broadcast(iota->shape(), {iota->iota_dimension()})); + result.Broadcast(iota->shape(), {iota->iota_dimension()})); } else { TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); parent_->evaluated_[iota] = std::move(result); @@ -2645,9 +2626,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> DynamicSlice( - const Literal& operand_literal, const Literal& start_indices_literal, - const Shape& result_shape) { + StatusOr DynamicSlice(const Literal& operand_literal, + const Literal& start_indices_literal, + const Shape& result_shape) { auto start_indices_typed = start_indices_literal.data(); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); @@ -2660,9 +2641,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector operand_indices(start.size()); - auto result = absl::make_unique(result_shape); + Literal result(result_shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); operand_indices[i] = multi_index[i] + start[i]; @@ -2676,12 +2657,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto result = operand_literal.CloneToUnique(); + StatusOr DynamicUpdateSlice(const Literal& operand_literal, + const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.Clone(); auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result->shape()); + const auto rank = ShapeUtil::Rank(result.shape()); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the @@ -2689,15 +2670,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { for (int64 i = 0; i < rank; ++i) { start[i] = std::min( std::max(0, start[i]), - result->shape().dimensions(i) - update_literal.shape().dimensions(i)); + result.shape().dimensions(i) - update_literal.shape().dimensions(i)); } std::vector result_index(rank, 0); auto func = [&](absl::Span update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus()); - result->Set(result_index, - update_literal.Get(update_index)); + result.Set(result_index, + update_literal.Get(update_index)); return true; }; @@ -2710,7 +2691,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result); } - StatusOr> ElementWiseUnaryOp( + StatusOr ElementWiseUnaryOp( HloInstruction* instruction, const std::function& unary_op) { const Literal& operand_literal = @@ -2723,7 +2704,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result_literal); } - StatusOr> ElementWiseBinaryOp( + StatusOr ElementWiseBinaryOp( HloInstruction* instruction, const std::function& binary_op) { @@ -2745,10 +2726,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -2757,7 +2738,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> ElementwiseTernaryOp( + StatusOr ElementwiseTernaryOp( HloInstruction* instruction, const std::function& ternary_op) { const auto shape = instruction->shape(); @@ -2782,10 +2763,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return ternary_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index), ehs_literal.Get(multi_index)); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 0345a2a5f871b0e5d09423d2c3bb48961e1ede2f..287ba84b3b24d3ec6dc21d157205ebc6a987c7d7 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -123,6 +123,10 @@ class NodeFilter { // We arbitrarily set this as the boundary between "large" and "small" // instructions. bool IsSmall(const HloInstruction* instr) { + if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) || + ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) { + return true; + } return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096; } @@ -465,9 +469,8 @@ stylesheet=< string graph_label = StrCat(label_, "
Computation ", computation_->name()); if (computation_->IsFusionComputation()) { - StrAppend(&graph_label, - StrCat(" (in fusion instruction ", - computation_->FusionInstruction()->name(), ")")); + StrAppend(&graph_label, " (in fusion instruction ", + computation_->FusionInstruction()->name(), ")"); } if (profile_ != nullptr) { auto cycles = profile_->total_cycles_executed(*computation_); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 25ae344ea5e8089e26765a5b7bd1b39123ffd454..e905f2983a43189eeb06824cf3078c235ab07925 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -250,7 +250,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.has_literal()); TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); - instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + instruction = CreateTrace(literal.GetR1U8AsString(), operands(0)); break; } case HloOpcode::kFusion: { @@ -505,6 +505,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -527,7 +528,7 @@ StatusOr> HloInstruction::CreateFromProto( } /* static */ std::unique_ptr HloInstruction::CreateConstant( - std::unique_ptr literal) { + Literal literal) { return absl::make_unique(std::move(literal)); } @@ -2096,7 +2097,7 @@ std::vector HloInstruction::ExtraAttributesToString( if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } - if (!control_predecessors_.empty()) { + if (options.print_control_dependencies() && !control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", StrJoin(control_predecessors_, ", ", [&](string* out, HloInstruction* pre) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5581c17c2d03cb01f63da7bbf06c6a3b9c972734..4f6cac1396c16beb5cebf909032dead711d77a61 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -82,6 +82,7 @@ class HloPrintOptions { print_operand_shape_(true), print_program_shape_(true), print_percent_(true), + print_control_dependencies_(true), canonicalize_instruction_names_(false), indent_amount_(0), is_in_nested_computation_(false) {} @@ -94,7 +95,8 @@ class HloPrintOptions { .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) - .set_print_percent(false); + .set_print_percent(false) + .set_print_control_dependencies(false); } // Options to produce the canonical string representing an isomorphic @@ -108,6 +110,7 @@ class HloPrintOptions { .set_print_operand_shape(true) .set_print_program_shape(false) .set_print_percent(false) + .set_print_control_dependencies(false) .set_canonicalize_instruction_names(true); } @@ -153,6 +156,12 @@ class HloPrintOptions { return *this; } + // If true, control dependencies will be printed. + HloPrintOptions& set_print_control_dependencies(bool value) { + print_control_dependencies_ = value; + return *this; + } + // If true, only a part of operands will be printed out, and their names will // be omitted (note that in this case the text will not be parsable). HloPrintOptions& set_compact_operands(bool value) { @@ -190,6 +199,9 @@ class HloPrintOptions { bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } + bool print_control_dependencies() const { + return print_control_dependencies_; + } bool canonicalize_instruction_names() const { return canonicalize_instruction_names_; } @@ -205,6 +217,7 @@ class HloPrintOptions { bool print_operand_shape_; bool print_program_shape_; bool print_percent_; + bool print_control_dependencies_; bool canonicalize_instruction_names_; int indent_amount_; bool is_in_nested_computation_; @@ -346,8 +359,7 @@ class HloInstruction { const string& name); // Creates a literal constant instruction. - static std::unique_ptr CreateConstant( - std::unique_ptr literal); + static std::unique_ptr CreateConstant(Literal literal); // Creates an Iota instruction. static std::unique_ptr CreateIota(const Shape& shape, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index fb7345a2ade972569ae2a031994bd7f71b034fc6..e92882c22a6ef1dd43440d3c94c7d233c9a4fb5d 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -845,8 +845,8 @@ std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_); } -HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) - : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), +HloConstantInstruction::HloConstantInstruction(Literal literal) + : HloInstruction(HloOpcode::kConstant, literal.shape()), literal_(std::move(literal)) {} HloConstantInstruction::HloConstantInstruction(const Shape& shape) @@ -854,7 +854,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape) HloInstructionProto HloConstantInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - if (literal_ != nullptr) { + if (literal_.has_value()) { *proto.mutable_literal() = literal_->ToProto(); } return proto; @@ -876,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, if (!mutable_array_subshape->has_layout() || !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); + *literal_ = literal_->Relayout(new_layout, shape_index); *mutable_array_subshape->mutable_layout() = new_layout; } } @@ -893,7 +893,8 @@ std::unique_ptr HloConstantInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return absl::make_unique(literal_->CloneToUnique()); + CHECK(literal_.has_value()); + return absl::make_unique(literal_->Clone()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( @@ -901,7 +902,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( CanonicalNameMap* canonical_name_map) const { string operands; // For constants, show the actual value in place of an empty operand list. - if (literal_ != nullptr && + if (literal_.has_value() && ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple @@ -936,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag, HloInstructionProto HloTraceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_literal() = literal_->ToProto(); + *proto.mutable_literal() = literal_.ToProto(); return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index c3a7801164737a7f1858b6a09a8af7896b3f4a8c..2d7bc83855e761ed313d831a1252a54130910bbe 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -580,13 +580,13 @@ class HloSliceInstruction : public HloInstruction { class HloConstantInstruction : public HloInstruction { public: - explicit HloConstantInstruction(std::unique_ptr literal); + explicit HloConstantInstruction(Literal literal); // Used when the literal is too large and dropped. explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. const Literal& literal() const { return *literal_; } // Returns whether there is literal associated with this instruction. - bool HasLiteral() const { return literal_ != nullptr; } + bool HasLiteral() const { return literal_.has_value(); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -610,15 +610,14 @@ class HloConstantInstruction : public HloInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. - std::unique_ptr literal_; + absl::optional literal_; }; class HloTraceInstruction : public HloInstruction { public: explicit HloTraceInstruction(const string& tag, HloInstruction* operand); // Returns a tag to be used in tracing. - string TracingTag() const { return literal_->GetR1U8AsString(); } + string TracingTag() const { return literal_.GetR1U8AsString(); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -631,8 +630,7 @@ class HloTraceInstruction : public HloInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. - std::unique_ptr literal_; + Literal literal_; }; class HloFusionInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc similarity index 97% rename from tensorflow/compiler/xla/service/hlo_scheduling.cc rename to tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 9bfb0af96ce73bbecb5e670430d7f6b7464c34d9..c7ec88d450712b0831971139f165934ef5524845 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include #include @@ -582,4 +582,22 @@ StatusOr ScheduleComputation( size_function, nullptr, empty_map); } +HloMemoryScheduler::HloMemoryScheduler( + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) + : size_function_(size_function), algorithm_(algorithm) {} + +StatusOr HloMemoryScheduler::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(*module, size_function_, algorithm_)); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + return true; +} + +StatusOr HloDescheduler::Run(HloModule* module) { + bool changed = module->has_schedule(); + module->clear_schedule(); + return changed; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h similarity index 71% rename from tensorflow/compiler/xla/service/hlo_scheduling.h rename to tensorflow/compiler/xla/service/hlo_memory_scheduler.h index 54e32340ba456a4fe5b1a939a8c8b81ad0813e2c..5e02868ebadaf06458f81e4f10ac04f882421ec8 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ #include #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -86,6 +87,37 @@ StatusOr ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); +// A pass which schedules the HLO instructions in a module. The HloModule's +// schedule field is set to the resulting HloSchedule using +// HloModule::set_schedule. +class HloMemoryScheduler : public HloPassInterface { + public: + // size_function is the function returning the number of bytes required for a + // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not + // specified, then DefaultMemoryScheduler is used. + HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); + ~HloMemoryScheduler() override = default; + absl::string_view name() const override { return "hlo-memory-scheduler"; } + + StatusOr Run(HloModule* module) override; + + private: + LogicalBuffer::SizeFunction size_function_; + MemorySchedulerAlgorithm algorithm_; +}; + +// A trivial pass which clears the schedule currently set on the +// HloModule. After this pass runs HloModudle::has_schedule will return false. +class HloDescheduler : public HloPassInterface { + public: + HloDescheduler() = default; + ~HloDescheduler() override = default; + absl::string_view name() const override { return "hlo-descheduler"; } + + StatusOr Run(HloModule* module) override; +}; + } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc similarity index 95% rename from tensorflow/compiler/xla/service/hlo_scheduling_test.cc rename to tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 6afe51997e903aa3adf6d7eb3beae4c527de8d52..1b9e9bfc77c3ba91e5b878f4aa42d26d8267a49a 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include #include @@ -67,22 +67,34 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); + HloMemoryScheduler scheduler([](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + // Verify that all instructions are in the sequence. const std::vector& sequence = - schedule.sequence(module->entry_computation()).instructions(); + module->schedule().sequence(module->entry_computation()).instructions(); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); // The first instruction should be the parameter and the last the root "sub". EXPECT_EQ(param, sequence.front()); EXPECT_EQ(sub, sequence.back()); - SequentialHloOrdering ordering(schedule); + SequentialHloOrdering ordering(module->schedule()); EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); + + // Clear the schedule using the descheduling pass. + HloDescheduler descheduler; + EXPECT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed, + descheduler.Run(module.get())); + EXPECT_TRUE(descheduler_changed); + EXPECT_FALSE(module->has_schedule()); } TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index cfe906d9c578d2755fca31ab406da1262cddf13f..b3949f3a6d7176950c61cafb0830d1175f17758d 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) { HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, - bool uniquify_names) { + bool uniquify_identifiers) { if (is_entry) { CHECK_EQ(nullptr, entry_computation_); entry_computation_ = computation.get(); @@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal( } } - if (uniquify_names) { + if (uniquify_identifiers) { computation->UniquifyName(&computation_name_uniquer_); for (auto* instruction : computation->instructions()) { instruction->UniquifyName(&instruction_name_uniquer_); } + + // Pick unique IDs for each instruction. + for (auto* instruction : computation->instructions()) { + instruction->SetUniqueId(NewUniqueInstructionId()); + } + // Set unique id to this computation. + CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); } else { // Don't uniquify the names of the computation or instruction, but we must // run the names through the uniquifiers to prevent future name collisions - // for computations and instructions created later. + // for computations and instructions created later. Also, set the + // next_unique_id_ to the one greater than the max unique id of any + // instruction (or the computation) to avoid ID collisions. computation_name_uniquer_.GetUniqueName(computation->name()); for (auto* instruction : computation->instructions()) { instruction_name_uniquer_.GetUniqueName(instruction->name()); + next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1); + } + if (next_unique_id_ < computation->unique_id() + 1) { + next_unique_id_ = computation->unique_id() + 1; } } - // Pick unique IDs for each instruction. - for (auto* instruction : computation->instructions()) { - instruction->SetUniqueId(NewUniqueInstructionId()); - } - // Set unique id to this computation. - CHECK_NE(computation->root_instruction()->unique_id(), -1) - << "Root has no valid id: " << computation->ToString(); - computation->SetUniqueId(computation->root_instruction()->unique_id()); - computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal( HloComputation* HloModule::AddEntryComputation( std::unique_ptr computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/true, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { @@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { HloComputation* HloModule::AddEmbeddedComputation( std::unique_ptr computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/false, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } void HloModule::ReplaceComputations( @@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const { /* static */ StatusOr> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(2) << "CreateFromProto()"; + XLA_VLOG_LINES(2, proto.DebugString()); + // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. TF_RET_CHECK(proto.has_program_shape()) @@ -312,22 +321,32 @@ StatusOr> HloModule::CreateFromProto( // Don't uniquify names because we want names to be stable across // serialization and deserialization. module->AddComputationInternal(std::move(computation), is_entry, - /*uniquify_names=*/false); + /*uniquify_identifiers=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); - // Because we didn't uniquify the names, double-check that the instruction and - // computation names are unique from the proto. + // Because we didn't uniquify the names or the ids, double-check that the + // instruction and computation names and ids are unique from the proto. tensorflow::gtl::FlatSet computation_names; tensorflow::gtl::FlatSet instruction_names; + tensorflow::gtl::FlatSet computation_ids; + tensorflow::gtl::FlatSet instruction_ids; for (HloComputation* computation : module->computations()) { TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); + + TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) + << "Computation id is not unique: " << computation->unique_id(); + computation_ids.insert(computation->unique_id()); for (HloInstruction* instruction : computation->instructions()) { TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) << "Instruction name is not unique: " << instruction->name(); instruction_names.insert(instruction->name()); + + TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) + << "Instruction id is not unique: " << instruction->unique_id(); + instruction_ids.insert(instruction->unique_id()); } } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 26fd1b243863850dda5ddac8f5c67fb214f5d927..3bc2d13781aa72738d695e37a02983ee82c6037d 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -253,7 +253,7 @@ class HloModule { private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, - bool uniquify_names); + bool uniquify_identifiers); const string name_; HloModuleConfig config_; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc index 98d20315e399c6b1a3979b5d11a89ef93869f4d9..f7be5cae2239e81d9aa1f5fb811a37c6086b028f 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -36,23 +36,6 @@ namespace xla { namespace { -bool HasSendRecv(HloComputation* computation) { - for (auto* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kSend || - instruction->opcode() == HloOpcode::kSendDone || - instruction->opcode() == HloOpcode::kRecv || - instruction->opcode() == HloOpcode::kRecvDone) { - return true; - } - for (auto* sub_computation : instruction->called_computations()) { - if (HasSendRecv(sub_computation)) { - return true; - } - } - } - return false; -} - StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { bool changed = false; for (auto* computation : module->computations()) { @@ -68,9 +51,10 @@ StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { if (!ShapeUtil::IsTuple(xla_while->shape()) || while_body_root->opcode() != HloOpcode::kTuple || - HasSendRecv(while_body_comp)) { + while_body_comp->HasSideEffect() || + xla_while->while_condition()->HasSideEffect()) { // Only run DCE on tuple-shaped while loops where body root is Tuple, - // with no send/recv instructions. + // with no I/O instructions. VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); continue; } diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index 363862e4905fc13a4ef07aeaac255259fc6b86ba..bf66cc6bc37a5e11c9ecfc07a62ba0ea5ca11a03 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -367,5 +367,77 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { "while.2", 1)); } +// Tests that a while whose body has outfeed operations is not DCE-ed. +TEST_F(HloModuleDceTest, WhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + WhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + WhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + +// Tests that if a loop variable is not referenced outside of a kWhile, the loop +// variable changes are not elided within the loop body, if the condition +// computation uses them. +TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) { + auto module = ParseHloString(R"( + HloModule InfiniteLoop + WhileBody { + body_param = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2) + } + WhileCondition { + cond_param = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + p0 = (s32[]) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(p0), index=0 + constant.3 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5) + while = (s32[], s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1 + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc new file mode 100644 index 0000000000000000000000000000000000000000..f9b56ef4643f2ca88e56456ae6c990161adb5085 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.cc @@ -0,0 +1,91 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +namespace xla { + +HloModuleGroup::HloModuleGroup(absl::string_view name, + std::unique_ptr module) + : name_(name) { + push_back(std::move(module)); +} + +HloModuleGroup::HloModuleGroup(absl::string_view name, + absl::Span> modules) + : name_(name) { + for (auto& module : modules) { + push_back(std::move(module)); + } +} + +std::vector> HloModuleGroup::ConsumeModules() { + std::vector> ret_modules = std::move(modules_); + + // Clear everything so the object state is in a known (empty) state. + modules_.clear(); + module_ptrs_.clear(); + return ret_modules; +} + +string HloModuleGroup::ToString() const { + std::ostringstream s; + s << "HloModuleGroup " << name() << "\n\n"; + for (const HloModule* module : modules()) { + s << module->ToString() << "\n"; + } + return s.str(); +} + +HloModuleGroupProto HloModuleGroup::ToProto() const { + HloModuleGroupProto proto; + proto.set_name(name()); + for (const HloModule* module : modules()) { + *proto.add_hlo_modules() = module->ToProto(); + } + return proto; +} + +/* static */ StatusOr HloModuleGroup::CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span module_configs) { + TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty"; + TF_RET_CHECK(proto.hlo_modules_size() > 0) + << "Module group must have at least one HLO module"; + TF_RET_CHECK(proto.hlo_modules_size() == module_configs.size()); + + std::vector> modules; + for (int i = 0; i < proto.hlo_modules_size(); ++i) { + const HloModuleProto& module_proto = proto.hlo_modules(i); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProto(module_proto, module_configs[i])); + modules.push_back(std::move(module)); + } + + return HloModuleGroup(proto.name(), absl::MakeSpan(modules)); +} + +void HloModuleGroup::push_back(std::unique_ptr module) { + modules_.push_back(std::move(module)); + module_ptrs_.push_back(modules_.back().get()); +} + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) { + out << group.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h new file mode 100644 index 0000000000000000000000000000000000000000..7338be8b9c5ed47f0ba5829cc1d603b21f00b6e0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +// An abstraction representing a ordered set of HLO module built to run +// concurrently across different devices. +class HloModuleGroup { + public: + // Construct an empty module group. + explicit HloModuleGroup(absl::string_view name) : name_(name) {} + + // Construct a module group containing a single module. + HloModuleGroup(absl::string_view name, std::unique_ptr module); + + // Construct a module group containing any number of modules. + HloModuleGroup(absl::string_view name, + absl::Span> modules); + + // Returns the modules contained in the group. + const std::vector& modules() const { return module_ptrs_; } + + // Returns a module at a particular index. + HloModule& module(int index) const { return *module_ptrs_.at(index); } + + // Add a module to the back of vector of modules in the group. + void push_back(std::unique_ptr module); + + // Moves all modules from the group into the returned vector. After this + // method runs, the module group will be empty. + std::vector> ConsumeModules(); + + string name() const { return name_; } + string ToString() const; + + // Serialize the module group to/from a proto. + HloModuleGroupProto ToProto() const; + static StatusOr CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span module_configs); + + private: + string name_; + + // Vector of modules as std::unique_ptrs. + std::vector> modules_; + + // Vector of modules as normal pointers. This vector is kept in sync with + // modules_ as modules are added to the group with push_back. + std::vector module_ptrs_; +}; + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ebf790ba6f1b5f9a7d4be8a8324420dbe11793f4 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { + +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +class HloModuleGroupTest : public HloTestBase { + protected: + HloModuleGroupTest() = default; +}; + +TEST_F(HloModuleGroupTest, SingleModule) { + const string text = R"( +HloModule simple_module + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + HloModuleGroup group(TestName(), std::move(module)); + + EXPECT_EQ(group.modules().size(), 1); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config()})); + EXPECT_EQ(group_copy.modules().size(), 1); + EXPECT_THAT( + group_copy.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + std::vector> modules = group.ConsumeModules(); + EXPECT_EQ(modules.size(), 1); + EXPECT_EQ(group.modules().size(), 0); +} + +TEST_F(HloModuleGroupTest, MultipleModules) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( +HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT %a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, + ParseHloString(text_1)); + std::vector> modules; + modules.push_back(std::move(module_0)); + modules.push_back(std::move(module_1)); + HloModuleGroup group(TestName(), absl::MakeSpan(modules)); + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config(), + group.module(1).config()})); + EXPECT_EQ(group_copy.modules().size(), 2); +} + +TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( +HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT %a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, + ParseHloString(text_1)); + HloModuleGroup group(TestName()); + group.push_back(std::move(module_0)); + group.push_back(std::move(module_1)); + + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 400bd4d94773e21bd08e78159415a734db50ca74..39f38b417ab0e8b54864176d8d1e0ad1a422eca6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -20,12 +20,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" - #include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" @@ -253,6 +253,99 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { op::Broadcast(), op::Multiply(), op::Add())); } +TEST_F(HloModuleTest, ProtoSerializationPreservesIds) { + // Verify that serializing then deserializing an HLO proto preserves the + // unique IDs of the instruction and module. + const string text = + R"(HloModule ReduceR3ToR2_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY ReduceR3ToR2.v3 { + input = f32[8,16,256]{2,1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + + // Perform various transformations on the graph: + // + // * clone the reduction function + // * replace use of reduction function with the clone. + // * add a random instruction to the entry computation. + // + // This will create instruction and computation IDs which are interesting: + // not consecutive and not densely packed. + HloComputation* entry = module->entry_computation(); + HloInstruction* root = entry->root_instruction(); + HloComputation* reduction = root->to_apply(); + HloComputation* reduction_clone = + module->AddEmbeddedComputation(reduction->Clone()); + root->set_to_apply(reduction_clone); + TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction)); + HloInstruction* negate = entry->AddInstruction( + HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root)); + entry->set_root_instruction(negate); + + // Schedule the transformed module, this verifies that the serialized schedule + // is robust against non-consecutive IDs as well (b/114712358). + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + HloMemoryScheduler scheduler(size_fn); + TF_ASSERT_OK(scheduler.Run(module.get()).status()); + ASSERT_TRUE(module->has_schedule()); + + // Serialize and deserialize and verify that the instruction and computations + // unique ids are the same. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + + // The module IDs should *not* be the same because module ids must be globally + // unique. + EXPECT_NE(module->unique_id(), module_copy->unique_id()); + + // Verify that the computations and instructions all have the same unique id. + auto computation_copy_it = module_copy->computations().begin(); + for (const HloComputation* computation_orig : module->computations()) { + const HloComputation* computation_copy = *computation_copy_it++; + EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id()) + << absl::StrFormat( + "ID of original computation %s != ID of deserialized " + "computation %s: %d != %d", + computation_orig->name(), computation_copy->name(), + computation_orig->unique_id(), computation_copy->unique_id()); + + auto instruction_copy_it = computation_copy->instructions().begin(); + for (const HloInstruction* instruction_orig : + computation_orig->instructions()) { + const HloInstruction* instruction_copy = *instruction_copy_it++; + EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id()) + << absl::StrFormat( + "ID of original instruction %s != ID of deserialized " + "instruction %s: %d != %d", + instruction_orig->name(), instruction_copy->name(), + instruction_orig->unique_id(), instruction_copy->unique_id()); + } + } + + // Verify that the next unique ID which the module would have handed out is + // greater than the unique id of any instruction. + int next_id = module_copy->NewUniqueInstructionId(); + for (const HloComputation* computation : module_copy->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + EXPECT_GT(next_id, instruction->unique_id()); + } + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 6b6005e7a5622fb1d6263c848aecd4834f62915f..00970bcda34209d33867099d0bcf3b2902d52ae8 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index c54360b063724ed93895e1f2ad4eb8774fef1d57..11caa89c545e8fbfad96a9ab8e448a68a565e423 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -105,16 +105,13 @@ class HloParser { string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); bool ParseControlPredecessors(HloInstruction* instruction); - bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseNonTupleLiteral(std::unique_ptr* literal, - const Shape& shape); - bool ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseSparseLiteral(std::unique_ptr* literal, - const Shape& shape); + bool ParseLiteral(Literal* literal, const Shape& shape); + bool ParseTupleLiteral(Literal* literal, const Shape& shape); + bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); + bool ParseDenseLiteral(Literal* literal, const Shape& shape); + bool ParseSparseLiteral(Literal* literal, const Shape& shape); template - bool ParseSparseLiteralHelper(std::unique_ptr* literal, - const Shape& shape); + bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape); // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. @@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConstant: { - std::unique_ptr literal; + Literal literal; if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || @@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { // literal // ::= tuple // ::= non_tuple -bool HloParser::ParseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) : ParseNonTupleLiteral(literal, shape); } @@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr* literal, // literal_list // ::= /*empty*/ // ::= literal (',' literal)* -bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return TokenError(StrCat("expects tuple constant in shape ", ShapeUtil::HumanString(shape))); @@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } - std::vector> elements( - ShapeUtil::TupleElementCount(shape)); + std::vector elements(ShapeUtil::TupleElementCount(shape)); if (lexer_.GetKind() == TokKind::kRparen) { // empty @@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, // ::= rank01 // ::= rank2345 // rank2345 ::= shape sparse_or_nested_array -bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { if (LayoutUtil::IsSparseArray(shape)) { return ParseSparseLiteral(literal, shape); } @@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, return ParseDenseLiteral(literal, shape); } -bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // TODO(congliu): bool type literals with rank >= 1 are actually // printed in a compact form instead of "true" or "false". Fix that. if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, - linear_index++, literal->get())) { + linear_index++, literal)) { return false; } lexer_.Lex(); @@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { @@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, loc, StrCat("expect floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else { @@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, } // end of switch } while (nest_level > 0); - *literal = (*literal)->Relayout(shape.layout()); + *literal = literal->Relayout(shape.layout()); return true; } -bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return false; } @@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, } template -bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { std::vector index; tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = absl::make_unique(shape); + *literal = Literal(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { @@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return false; } - if ((*literal)->sparse_element_count() + 1 == + if (literal->sparse_element_count() + 1 == LayoutUtil::MaxSparseElements(shape.layout())) { return Error( lexer_.GetLoc(), @@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, ShapeUtil::HumanStringWithLayout(shape))); } - (*literal)->AppendSparseElement(index, value); + literal->AppendSparseElement(index, value); } - (*literal)->SortSparseElements(); + literal->SortSparseElements(); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index 585c95972b0e01abc14543205af71b4b0c0bdf3c..d9848cee0bfa904a90aea4626c3ee62c2cbb45b6 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" namespace xla { namespace { -class HloReachabilityTest : public HloTestBase {}; +class HloReachabilityTest : public HloVerifiedTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 0a0a6a323e6bb223bba305aaddbe1370da1916c2..bd6dd79b679729adb6691ef809b19f06c6d5dd05 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -27,15 +27,14 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -1194,51 +1193,12 @@ StatusOr HloRematerialization::RematerializeComputation( return changed; } -StatusOr HloRematerialization::Run(HloModule* module, - HloSchedule* schedule, - int64 memory_limit_bytes, - RematerializationSizes* sizes, - CopyInsertion* copy_insertion) { - // The schedule is constructed entirely by this method. - TF_RET_CHECK(schedule->empty()); - +StatusOr HloRematerialization::Run(HloModule* module) { VLOG(1) << "HloRematerialization() with memory limit of " - << HumanReadableNumBytes(memory_limit_bytes); + << HumanReadableNumBytes(memory_limit_bytes_); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial schedule of HLO instructions. - TF_ASSIGN_OR_RETURN(*schedule, - ScheduleModule(*module, - [this](const BufferValue& buffer) { - return size_function_(buffer.shape()); - }, - scheduler_algorithm_)); - if (copy_insertion) { - // We run a separate pass of copy elision here because the sequential - // ordering from the HLO schedule allows for more copies to be eliminated. - // TODO(b/80249101): Instead of a separate copy elision pass, use the - // ordering from the HLO schedule directly for copy insertion. - SequentialHloOrdering ordering(*schedule); - TF_RETURN_IF_ERROR( - copy_insertion->RemoveUnnecessaryCopies(ordering, module)); - - // RemoveUnnecessaryCopies only considers interference when determining - // whether it is legal to remove a copy. However, copies in the graph may be - // necessary for other reason such as preventing a constant from being live - // out of the graph. So run AddSpecialCaseCopies to re-insert these copies. - // TODO(b/80249101): Break copy insertion into several passes and run each - // one once in the regular HLO pipeline. - TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module)); - - // The passes above can add and remove copies, update the schedule to - // account for these transformations. Newly added instructions will be - // placed ASAP in the schedule. - TF_RETURN_IF_ERROR(schedule->Update()); - - TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( - SequentialHloOrdering(*schedule), module)); - } - + TF_RET_CHECK(module->has_schedule()); TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); // Adjust memory limit to account for the output of the entry @@ -1254,7 +1214,7 @@ StatusOr HloRematerialization::Run(HloModule* module, }); const int64 adjusted_memory_limit_bytes = - memory_limit_bytes - module_output_size; + memory_limit_bytes_ - module_output_size; VLOG(1) << "Adjusted memory limit accounting for output (" << HumanReadableNumBytes(module_output_size) << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); @@ -1263,13 +1223,14 @@ StatusOr HloRematerialization::Run(HloModule* module, // sequential context. call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( - [this, schedule](const CallGraphNode& node) -> Status { + [this, module](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], - ComputePeakMemory( - node.computation(), - schedule->sequence(node.computation()).instructions())); + ComputePeakMemory(node.computation(), + module->schedule() + .sequence(node.computation()) + .instructions())); } return Status::OK(); }, @@ -1287,9 +1248,10 @@ StatusOr HloRematerialization::Run(HloModule* module, // Subcomputations called by the entry computation will also be // rematerialized. - TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( - module->entry_computation(), schedule, - adjusted_memory_limit_bytes)); + TF_ASSIGN_OR_RETURN( + bool changed, + RematerializeComputation(module->entry_computation(), &module->schedule(), + adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. @@ -1298,7 +1260,7 @@ StatusOr HloRematerialization::Run(HloModule* module, // After DCE, the module sequence may include instructions which no longer // exist. - TF_RETURN_IF_ERROR(schedule->Update()); + TF_RETURN_IF_ERROR(module->schedule().Update()); VLOG(1) << "Rematerialized " << instructions_rematerialized_ << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; @@ -1315,32 +1277,22 @@ StatusOr HloRematerialization::Run(HloModule* module, << HumanReadableNumBytes(reduced_peak_memory) << " (" << reduced_peak_memory << " bytes)"; - if (sizes != nullptr) { - sizes->before_bytes = before_peak_memory; - sizes->after_bytes = current_peak_memory; + if (sizes_ != nullptr) { + sizes_->before_bytes = before_peak_memory; + sizes_->after_bytes = current_peak_memory; } XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); - if (current_peak_memory > memory_limit_bytes) { + if (current_peak_memory > memory_limit_bytes_) { LOG(WARNING) << absl::StrFormat( "Can't reduce memory use below %s (%d bytes) by rematerialization; " "only reduced to %s (%d bytes)", - HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes, + HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_, HumanReadableNumBytes(current_peak_memory), current_peak_memory); } return changed; } -/* static */ StatusOr HloRematerialization::RematerializeAndSchedule( - const HloRematerialization::ShapeSizeFunction& size_function, - int64 memory_limit_bytes, HloModule* hlo_module, - MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule, - RematerializationSizes* sizes, CopyInsertion* copy_insertion) { - HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes, - copy_insertion); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index fa0414b4728275e640cecc91ccbd2a7e4c9585ff..e2aaf18b3e482bbf777c594c7f5a22832be2ac17 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -17,17 +17,23 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { -class HloRematerialization { +// HLO pass which rematerializes instructions to reduce peak memory use, where +// memory use is defined as the total size of all live HLO instruction +// values. Parameters and constants are included in memory use estimates. +// +// CSE will undo the effects of this optimization and should not be run after +// this pass. In general, this pass should be run very late, immediately before +// code generation. +class HloRematerialization : public HloPassInterface { public: using ShapeSizeFunction = std::function; @@ -38,10 +44,7 @@ class HloRematerialization { int64 after_bytes; }; - // Rematerialize HLO instructions in the given module to reduce peak memory - // use below memory_limit_bytes where memory use is defined as the total size - // of all live HLO instruction values. Parameters and constants are included - // in memory use estimates. Method parameters: + // Constructor parameters: // // size_function: Function which returns the size in bytes of the top-level // buffer of the given shape. @@ -49,51 +52,27 @@ class HloRematerialization { // memory_limit_bytes: The threshold number of bytes to reduce memory use to // via rematerialization. // - // hlo_module: HLO module to rematerialize instructions in. - // - // schedule: Should point to an empty HloSchedule. Upon return - // contains the HLO instruction order which was used for - // rematerialization. This is the order in which HLO instructions should - // be emitted to minimize memory use. - // - // sizes: Optional outparam that indicates the peak memory usage of the HLO - // module before/after rematerialization. - // - // copy_insertion: If non-null, run copy elision after scheduling. This - // pass is used to eliminate copies that were inserted by copy insertion - // before HLO scheduling. - // - // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy - // insertion is integrated with HLO scheduling. - // - // Returns whether any instructions were rematerialized. If memory use is - // already below the given limit then no instructions are rematerialized and - // false is returned. - // - // CSE will undo the effects of this optimization and should not be run after - // this pass. In general, this pass should be run very late immediately before - // code generation. - static StatusOr RematerializeAndSchedule( - const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, - HloSchedule* schedule, RematerializationSizes* sizes, - CopyInsertion* copy_insertion = nullptr); - - protected: - HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, - const ShapeSizeFunction& size_function) - : scheduler_algorithm_(scheduler_algorithm), - size_function_(size_function) {} + // sizes: Pointer to data structure which records the peak memory usage of + // the HLO module before/after rematerialization. Value are set during + // Run(). Can be nullptr. + HloRematerialization(const ShapeSizeFunction& size_function, + int64 memory_limit_bytes, RematerializationSizes* sizes) + : size_function_(size_function), + memory_limit_bytes_(memory_limit_bytes), + sizes_(sizes) {} ~HloRematerialization() {} + absl::string_view name() const override { return "rematerialization"; } + // Runs rematerialization on the given module. Returns whether the module was - // changed. memory_limit is the target maximum peak memory usage by the - // module. schedule should be an empty HloSchedule. Upon return sequence - // contains the memory-minimizing order in which to emit the HLO instructions. - StatusOr Run(HloModule* module, HloSchedule* schedule, - int64 memory_limit, RematerializationSizes* sizes, - CopyInsertion* copy_insertion); + // changed. Requires that the module has a schedule set + // (HloModule::has_schedule() is true) before running. Returns whether any + // instructions were rematerialized. If memory use is already below the limit + // specified in the constructor then no instructions are rematerialized and + // false is returned. + StatusOr Run(HloModule* module) override; + protected: // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the // backend. Rematerialized instructions will be added to the HLO computation @@ -121,6 +100,14 @@ class HloRematerialization { // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; + // The threshold number of bytes to reduce memory use to via + // rematerialization. + const int64 memory_limit_bytes_; + + // Pointer to data structure which records the peak memory usage of the HLO + // module before/after rematerialization + RematerializationSizes* sizes_; + // Call graph of the hlo_module. std::unique_ptr call_graph_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 83cb113bfb81c9e3e712efc2e4a8431818ab0048..f7e82fb1f88e856305f6f481a451d4cd64ba4acf 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloTestBase { +class HloRematerializationTest : public HloVerifiedTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: @@ -142,12 +142,15 @@ class HloRematerializationTest : public HloTestBase { } StatusOr RunHloRematerialization(int64 memory_limit_bytes, - HloModule* module, - HloSchedule* schedule) { + HloModule* module) { TF_EXPECT_OK(verifier().Run(module).status()); - return HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - schedule, /*sizes=*/nullptr); + HloMemoryScheduler scheduler( + [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); }, + DefaultMemoryScheduler); + TF_EXPECT_OK(scheduler.Run(module).status()); + HloRematerialization remat(ByteSizeOf, memory_limit_bytes, + /*sizes=*/nullptr); + return remat.Run(module); } // Various shapes used in the canned computations. @@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { const HloInstruction* concat = slice->operand(0); const HloInstruction* bcast = concat->operand(0); - HloSchedule schedule(module.get()); // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, module)); EXPECT_TRUE(changed); // Root should not have changed. @@ -187,10 +189,12 @@ TEST_F(HloRematerializationTest, SingleComputation) { // The rematerialized broadcast should be immediate before the concat in the // sequence. - EXPECT_EQ(schedule.sequence(computation) + EXPECT_EQ(module->schedule() + .sequence(computation) .instructions()[computation->instruction_count() - 2], concat); - EXPECT_EQ(schedule.sequence(computation) + EXPECT_EQ(module->schedule() + .sequence(computation) .instructions()[computation->instruction_count() - 3], remat_bcast); } @@ -205,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 8); - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, module)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -244,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, module)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -278,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, module)); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -318,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, module)); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. @@ -384,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { ASSERT_EQ(count_rngs(entry_computation), 1); const int64 original_instruction_count = entry_computation->instruction_count(); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( - bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), &schedule)); + bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -478,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_EQ(add_3->operand(0), bcast); EXPECT_EQ(add_4->operand(0), bcast); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -573,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 66ac1f66fd035074c69d070821a951fd0e357289..fa7f216321988137dcf9104a324f5f7789869aa5 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -118,16 +118,16 @@ StatusOr> HloRunner::TransferLiteralsToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const absl::Span> literals) { + const absl::Span literals) { std::vector literal_pointers; literal_pointers.reserve(literals.size()); for (const auto& literal : literals) { - literal_pointers.push_back(literal.get()); + literal_pointers.push_back(&literal); } return TransferLiteralsToDevice(literal_pointers); } -StatusOr> HloRunner::TransferLiteralFromDevice( +StatusOr HloRunner::TransferLiteralFromDevice( const ShapedBuffer& buffer) { TF_ASSIGN_OR_RETURN( auto stream, backend().BorrowStream(backend().default_stream_executor())); @@ -135,7 +135,7 @@ StatusOr> HloRunner::TransferLiteralFromDevice( buffer); } -StatusOr> HloRunner::Execute( +StatusOr HloRunner::Execute( std::unique_ptr module, const absl::Span arguments, bool run_hlo_passes, ExecutionProfile* profile) { @@ -150,15 +150,15 @@ StatusOr> HloRunner::Execute( return TransferLiteralFromDevice(result); } -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const absl::Span> arguments, - bool run_hlo_passes, ExecutionProfile* profile) { +StatusOr HloRunner::Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes, + ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. std::vector argument_pointers; argument_pointers.reserve(arguments.size()); for (const auto& argument : arguments) { - argument_pointers.push_back(argument.get()); + argument_pointers.push_back(&argument); } return Execute( /*module=*/std::move(module), @@ -204,7 +204,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( /*profile=*/profile); } -StatusOr>> HloRunner::ExecuteReplicated( +StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) { TF_ASSIGN_OR_RETURN( @@ -290,9 +290,9 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = absl::make_unique(); + Literal literal; TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, options.outfeed_shape, literal.get())); + executor, options.outfeed_shape, &literal)); if (options.outfeed_values != nullptr) { options.outfeed_values->push_back(std::move(literal)); } @@ -310,10 +310,10 @@ StatusOr>> HloRunner::ExecuteReplicated( argument_buffer_slices)); LOG(INFO) << "Replicated execution terminated"; - std::vector> exec_results; + std::vector exec_results; for (int64 i = 0; i < options.num_replicas; ++i) { TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, backend().transfer_manager()->TransferLiteralFromDevice( streams[i].get(), results[i])); exec_results.push_back(std::move(literal)); diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 76d8b92bed484381a59d7f54e0a75bb7e75649ee..2e934bf66ae43ea412f242030b874dddb6d3722d 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -72,7 +72,7 @@ class HloRunner { // A pointer to a vector where the outfeed values will be stored. If // nullptr, the values will be read and discarded. - std::vector>* outfeed_values = nullptr; + std::vector* outfeed_values = nullptr; // Whether the HLO passes should be run on the input module. Usually // saved modules are coming from after the HLO pass pipeline, so triggering @@ -106,24 +106,23 @@ class HloRunner { StatusOr> TransferLiteralsToDevice( const absl::Span literals); StatusOr> TransferLiteralsToDevice( - const absl::Span> literals); - StatusOr> TransferLiteralFromDevice( - const ShapedBuffer& buffer); + const absl::Span literals); + StatusOr TransferLiteralFromDevice(const ShapedBuffer& buffer); // Executes the given module with given literals as input and returns the // result as a Literal. // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. - StatusOr> Execute( - std::unique_ptr module, - const absl::Span arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); - StatusOr> Execute( - std::unique_ptr module, - const absl::Span> arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); // As Execute(), but accepts and returns device buffers instead of host // buffers. @@ -140,7 +139,7 @@ class HloRunner { // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. - StatusOr>> ExecuteReplicated( + StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options); diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc index eb52582bb5e17c2165251574a80ec348224584e6..1424569ac1f62e4b965876141f1eb40be4f15bea 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1..6fd734a2b9e6c8c9fca76a944ca3df4c3b8a212f 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -24,7 +24,7 @@ namespace { using ::tensorflow::GraphDef; -class HloTfGraphBuilderTest : public HloTestBase { +class HloTfGraphBuilderTest : public HloVerifiedTestBase { protected: HloTfGraphBuilderTest() {} HloTfGraphBuilder generator_; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 069586a738d10b2a22c6032ca1d29a66a1ccdc6e..50f39cbcb55e29a2654ed8c745ea24ee2e0ab899 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -1123,6 +1123,11 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + // If the module has a schedule, it must be valid. + if (module->has_schedule()) { + TF_RETURN_IF_ERROR(module->schedule().Verify()); + } + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 0cac210c2413e979300e191cb54860bcd0ab79b5..8f0423bb1c72ceb209437116a898d027f4d2c657 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -290,8 +290,8 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { padding_config.add_dimensions()->set_interior_padding(-1); builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {100}), param, - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(F32).CloneToUnique())), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), padding_config)); auto module = CreateNewModule(); @@ -314,8 +314,8 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { padding_config.add_dimensions()->set_interior_padding(-1); builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {100}), param, - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(F32).CloneToUnique())), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), padding_config)); auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 37b774b8a5d80643e5833480df12657a3a3ea5f2..06f0e1ed25e71659a61e6de8a84e52cf70064eae 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -918,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, // inner_broadcast_result is the Broadcast'(Const0) bit in // BinaryOp(Broadcast'(Const0), Const1) TF_ASSIGN_OR_RETURN( - std::unique_ptr inner_broadcast_result, + Literal inner_broadcast_result, broadcast_const_operand->literal().Broadcast( scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); @@ -928,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, scalar_indexed_const->literal(), *inner_broadcast_result))); + opcode, scalar_indexed_const->literal(), inner_broadcast_result))); } else { TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, *inner_broadcast_result, scalar_indexed_const->literal()))); + opcode, inner_broadcast_result, scalar_indexed_const->literal()))); } ConstantArray* new_source = Construct(literal_for_new_source); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 9746d176ccf1f17353725515566a3dd79f3a79d4..df9cbab915cc037cec682238886fb524eaeb2c90 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -347,21 +347,19 @@ class IndexedArrayAnalysis { } } - Literal* TakeOwnership(std::unique_ptr literal) { + Literal* TakeOwnership(Literal literal) { owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } - StatusOr TakeOwnership( - StatusOr> literal_or_error) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - std::move(literal_or_error)); + StatusOr TakeOwnership(StatusOr literal_or_error) { + TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error)); owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } std::vector> owned_tensors_; - std::vector> owned_literals_; + std::vector owned_literals_; tensorflow::gtl::FlatMap cache_; }; diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 5695bc242057c037a1999e7d63f5b4f21b5f658a..7e967f035c1054e22d10790188a5a232ca8e751a 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using InlinerTest = HloTestBase; +using InlinerTest = HloVerifiedTestBase; // Test that `map` with `max` is transformed to `max` TEST_F(InlinerTest, MapMax) { @@ -64,14 +64,14 @@ TEST_F(InlinerTest, MapMax) { hlo_module->AddEntryComputation(std::move(computation)); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } // Test that `constant` function is changed to `broadcast`. @@ -98,14 +98,14 @@ TEST_F(InlinerTest, MapConstant) { hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } TEST_F(InlinerTest, MapSubtractOppositeOrder) { @@ -136,14 +136,14 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { hlo_module->AddEntryComputation(std::move(computation)); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR1({3, 1, -1, -3}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 8c907eae0cbe7c3764a2bfe8fed6b6098931de38..3fdc2cee9aad0fe70f66920f757ee5c52bba711f 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" @@ -295,6 +296,138 @@ InstructionFusion::ComputeGloballyUnfusible( return do_not_duplicate; } +namespace { + +// A FusionQueue that uses reverse post order. +// +// We want to be able to remove arbitrary instructions from the post order and +// also compare positions of instructions in the post order. To make this +// possible, create vector of instructions in post order and create a map from +// HloInstruction* to the instruction's index in the vector. An instruction is +// "removed" from the vector by setting it's element to nullptr. +class ReversePostOrderFusionQueue : public FusionQueue { + public: + explicit ReversePostOrderFusionQueue(HloComputation* computation) { + post_order_ = computation->MakeInstructionPostOrder(); + + for (size_t i = 0; i < post_order_.size(); ++i) { + InsertOrDie(&post_order_index_, post_order_[i], i); + } + } + + std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() override { + // Instructions are "removed" from the post order by nulling out the element + // in the vector, so if the pointer is null, continue to the next + // instruction in the sort. + while (!post_order_.empty() && post_order_.back() == nullptr) { + post_order_.pop_back(); + } + if (post_order_.empty()) { + return std::pair>{nullptr, {}}; + } + // We want to iterate in reverse post order, so remove from the back of the + // vector. + HloInstruction* instruction = post_order_.back(); + post_order_.pop_back(); + + CHECK(instruction != nullptr); + // Remove instruction from the index map to ensure the vector and map stay + // consistent. + post_order_index_.erase(instruction); + + // Consider each operand of this instruction for fusion into this + // instruction. We want to consider the operands in a particular order to + // avoid creating duplicate instruction clones in the fusion instruction. + // For example, consider the following expression: + // + // A = ... + // B = op(A) + // C = op(A, B) + // + // If we are considering the operands of C for fusion into C. We might + // fuse A or B first. If we fuse A first, we get: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // C' = op(A', B) } + // + // Where A' and C' are clones of A and C, respectively. Now only B is an + // operand of the fusion instruction C_fusion, so then we fuse B: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // B' = op(A) + // C' = op(A', B') } + // + // Now A is an operand of C_fusion again, so we then fuse A (again!): + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // A" = .. + // B' = op(A") + // C' = op(A', B') } + // + // We prevent this duplication by considering the operands in the order + // they appear int the queue. In the example, this ensures that B will be + // considered before A. + // + // We store the original indices of the operands to pass to ShouldFuse. + std::vector sorted_operand_numbers; + sorted_operand_numbers.reserve(instruction->operands().size()); + for (int i = 0; i < instruction->operands().size(); ++i) { + // This will happen if we have two possible instructions to fuse the + // same operand into; once the operand is fused into one instruction, + // the other instruction will get a new get-tuple-element as its + // operand, which is not in the queue. + // TODO(tjoerg): Look into fusing past these multi-output fuse points. + if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) { + continue; + } + sorted_operand_numbers.push_back(i); + } + std::sort( + sorted_operand_numbers.begin(), sorted_operand_numbers.end(), + [&](int64 i, int64 j) { + // Instructions with higher priority in the queue come first. + return ( + FindOrDie(post_order_index_, instruction->mutable_operand(i)) > + FindOrDie(post_order_index_, instruction->mutable_operand(j))); + }); + return std::make_pair(instruction, sorted_operand_numbers); + } + + void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) override { + // Fusing an instruction into a fusion instruction can change the operand + // set of the fusion instruction. For simplicity just re-enqueue the + // instruction and reconsider it for further fusion in the next iteration. + InsertOrDie(&post_order_index_, fusion, post_order_.size()); + post_order_.push_back(fusion); + } + + void RemoveInstruction(HloInstruction* instruction) override { + post_order_[FindOrDie(post_order_index_, instruction)] = nullptr; + post_order_index_.erase(instruction); + } + + private: + std::vector post_order_; + tensorflow::gtl::FlatMap post_order_index_; +}; + +} // namespace + +std::unique_ptr InstructionFusion::GetFusionQueue( + HloComputation* computation, + const std::function& skip_producer) { + return absl::make_unique(computation); +} + StatusOr InstructionFusion::Run(HloModule* module) { VLOG(2) << "Before instruction fusion:"; XLA_VLOG_LINES(2, module->ToString()); @@ -306,111 +439,31 @@ StatusOr InstructionFusion::Run(HloModule* module) { computation_ = computation; reachability_ = computation_->ComputeReachability(); - // We want to be able to remove arbitrary instructions from the post order - // and also compare positions of instructions in the post order. To make - // this possible, create vector of instructions in post order and create a - // map from HloInstruction* to the instruction's index in the vector. An - // instruction is "removed" from the vector by setting it's element to - // nullptr. - std::vector post_order = - computation_->MakeInstructionPostOrder(); - - tensorflow::gtl::FlatMap post_order_index; - for (size_t i = 0; i < post_order.size(); ++i) { - InsertOrDie(&post_order_index, post_order[i], i); - } - - HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order); + HloInstructionSet do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + auto fusion_queue = + GetFusionQueue(computation_, [&](HloInstruction* producer) { + return do_not_duplicate.count(producer) > 0; + }); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all // edges. When we fuse an edge, we create a copy of the producer inside the // fusion instruction. - while (!post_order.empty()) { - // We want to iterate in reverse post order, so remove from the back of - // the vector. - HloInstruction* instruction = post_order.back(); - post_order.pop_back(); - - // Instructions are "removed" from the post order by nulling out the - // element in the vector, so if the pointer is null, continue to the next - // instruction in the sort. + while (true) { + auto next_entry = + fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder(); + auto instruction = next_entry.first; if (instruction == nullptr) { - continue; + break; } - // Remove instruction from the index map to ensure the vector and map stay - // consistent. - post_order_index.erase(instruction); - if (!instruction->IsFusible() && instruction->opcode() != HloOpcode::kFusion) { continue; } - // Consider each operand of this instruction for fusion into this - // instruction. We want to consider the operands in a particular order to - // avoid creating duplicate instruction clones in the fusion instruction. - // For example, consider the following expression: - // - // A = ... - // B = op(A) - // C = op(A, B) - // - // If we are considering the operands of C for fusion into C. We might - // fuse A or B first. If we fuse A first, we get: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // C' = op(A', B) } - // - // Where A' and C' are clones of A and C, respectively. Now only B is an - // operand of the fusion instruction C_fusion, so then we fuse B: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // B' = op(A) - // C' = op(A', B') } - // - // Now A is an operand of C_fusion again, so we then fuse A (again!): - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // A" = .. - // B' = op(A") - // C' = op(A', B') } - // - // We prevent this duplication by considering the operands in the reverse - // order they appear in the instruction post order. In the example, this - // ensures that B will be considered before A. - // - // We store the original indices of the operands to pass to ShouldFuse. - std::vector sorted_operand_numbers; - sorted_operand_numbers.reserve(instruction->operands().size()); - for (int i = 0; i < instruction->operands().size(); ++i) { - // This will happen if we have two possible instructions to fuse the - // same operand into; once the operand is fused into one instruction, - // the other instruction will get a new get-tuple-element as its - // operand, which is not in the post-order index. - // TODO(tjoerg): Look into fusing past these multi-output fuse points. - if (post_order_index.find(instruction->mutable_operand(i)) == - post_order_index.end()) { - continue; - } - sorted_operand_numbers.push_back(i); - } - std::sort( - sorted_operand_numbers.begin(), sorted_operand_numbers.end(), - [&](int64 i, int64 j) { - // Instructions with higher indices in the post order come - // first. - return ( - FindOrDie(post_order_index, instruction->mutable_operand(i)) > - FindOrDie(post_order_index, instruction->mutable_operand(j))); - }); + std::vector& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); @@ -425,32 +478,31 @@ StatusOr InstructionFusion::Run(HloModule* module) { // TODO(tjoerg): Consider making multi-output fusion the default. if (ShouldFuse(instruction, i) && do_not_duplicate.count(operand) == 0) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = Fuse(operand, instruction); } else if (ShouldFuseIntoMultiOutput(instruction, i) && !MultiOutputFusionCreatesCycle(operand, instruction)) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = FuseIntoMultiOutput(operand, instruction); } else { continue; } - // Fusing an instruction into a fusion instruction can change the - // operand set of the fusion instruction. For simplicity just push the - // instruction to the top of the post_order and reconsider it for - // further fusion in the next iteration of the outer loop. - post_order.push_back(fusion_instruction); - InsertOrDie(&post_order_index, fusion_instruction, - post_order.size() - 1); + fusion_queue->OnFusingInstruction(fusion_instruction, operand, + instruction); changed = true; if (operand->user_count() == 0) { - // Operand is now dead. Remove from post order by setting its - // location to nullptr. - post_order[FindOrDie(post_order_index, operand)] = nullptr; - post_order_index.erase(operand); - + do_not_duplicate.erase(operand); + // Operand is now dead. Remove from queue. + fusion_queue->RemoveInstruction(operand); // Remove from computation. TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand)); } + + if (fusion_instruction != instruction) { + do_not_duplicate.erase(instruction); + } break; } } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 00b658959a2cceeb30d2ec03f243119ec0a8ee47..c1fde8ecfc04792c6c17ebd83190486ef720175a 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -24,6 +24,33 @@ limitations under the License. namespace xla { +// A queue interface that allows implementations to choose fusion candidates in +// custom order. +class FusionQueue { + public: + FusionQueue() = default; + virtual ~FusionQueue() = default; + + // Dequeues the next fusion candidates: a consumer and the list of producers + // as operand indices. + virtual std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() = 0; + + // A callback passed to the queue implementation right before the producer is + // fused into the consumer. + virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} + + // A callback passed to the queue implementation right after the fusion is + // created. Note that original_producer could have been destroyed. + virtual void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) {} + + // A callback passed to the queue implementation to notify the removal of an + // instruction. + virtual void RemoveInstruction(HloInstruction* instruction) = 0; +}; + // HLO pass which performs instruction fusion. Instructions are fused // "vertically", meaning producing instructions are fused into their consumers // with the intent that the loops which compute their values will be fused in @@ -48,6 +75,13 @@ class InstructionFusion : public HloPassInterface { static bool IsExpensive(const HloInstruction& instruction); protected: + // Returns a FusionQueue that implements custom order of instructions being + // fused. The default implementation processes consumers in reverse post + // order. + virtual std::unique_ptr GetFusionQueue( + HloComputation* computation, + const std::function& skip_producer); + // Returns whether the given producer instruction should be fused into the // given consumer instruction. producer is necessarily an operand of consumer. // Derived classes should define this method to specify which instructions diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 5dea12476849db6f7a9a9214398b4e57262aeda0..a06d6113e84630df14ff68280c248cccb9afaf06 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -73,30 +73,29 @@ StatusOr InterpreterExecutable::ExecuteOnStream( // Transform the ShapedBuffer arguments into literals which the evaluator // consumes. - std::vector> arg_literals; + std::vector arg_literals; for (int64 p = 0; p < computation->num_parameters(); ++p) { - TF_ASSIGN_OR_RETURN(std::unique_ptr arg_literal, + TF_ASSIGN_OR_RETURN(Literal arg_literal, transfer_manager->TransferLiteralFromDevice( run_options->stream(), *arguments[p])); arg_literals.push_back(std::move(arg_literal)); } // Execute the graph using the HloEvaluator. - std::unique_ptr result_literal; + Literal result_literal; { tensorflow::mutex_lock lock(evaluator_lock_); - TF_ASSIGN_OR_RETURN(result_literal, - evaluator_->Evaluate>( - *computation, arg_literals)); + TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( + *computation, arg_literals)); } // Transform the result literal back into a ShapedBuffer. TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, transfer_manager->AllocateScopedShapedBuffer( - result_literal->shape(), run_options->allocator(), + result_literal.shape(), run_options->allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - run_options->stream(), *result_literal, result)); + run_options->stream(), result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 69c7e426017649a0d37128e328c19440a6e64096..752a61476dd7892a2b7f531c4057015f48fc4758 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -35,7 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,7 +49,7 @@ namespace { using ::testing::ElementsAre; -class LayoutAssignmentTest : public HloTestBase { +class LayoutAssignmentTest : public HloVerifiedTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout, @@ -91,7 +91,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { *computation_layout.mutable_parameter_layout(0) = shape_layout; *computation_layout.mutable_parameter_layout(1) = shape_layout; *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); @@ -127,7 +127,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { *computation_layout.mutable_parameter_layout(1) = row_major; *computation_layout.mutable_result_layout() = col_major; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( @@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); auto constant_literal2 = LiteralUtil::CreateR2WithLayout( {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); - Shape ashape = constant_literal1->shape(); + Shape ashape = constant_literal1.shape(); auto constant1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(constant_literal1))); @@ -172,7 +172,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal( layout, fusion->fused_parameter(0)->shape().layout())); @@ -213,7 +213,7 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); @@ -243,7 +243,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -255,7 +255,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); } @@ -294,7 +294,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { result_shape)); LayoutAssignment layout_assignment(&computation_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); // Layout assignment should have deep copied the result of the computation to // address the layout conflict. This results in several Tuple() and @@ -310,7 +310,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { EXPECT_TRUE( AlgebraicSimplifier(/*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return false; }) - .Run(module.get()) + .Run(module) .ValueOrDie()); HloInstruction* root = module->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. @@ -352,7 +352,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); @@ -393,7 +393,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); @@ -432,7 +432,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(input_shape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(output_shape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1, 2)); @@ -457,13 +457,13 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, f32_4, "param")); auto broadcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_34, param, {3})); + HloInstruction::CreateBroadcast(f32_34, param, {1})); auto transpose = builder.AddInstruction( HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0})); auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast)); auto broadcast2 = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_234, tanh, {2})); + HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2})); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({transpose, broadcast2})); auto module = CreateNewModule(); @@ -485,7 +485,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {transpose_shape_with_layout, broadcast2_shape_with_layout})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); @@ -551,7 +551,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { *computation_layout.mutable_parameter_layout(1) = ShapeLayout(param1_shape_with_layout); OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); - EXPECT_IS_OK(layout_assignment.Run(module.get()).status()); + EXPECT_IS_OK(layout_assignment.Run(module).status()); EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), @@ -575,7 +575,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -593,7 +593,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -659,18 +659,18 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); - module = + std::unique_ptr compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); EXPECT_EQ(Status::OK(), backend() .compiler() - ->RunBackend(std::move(module), + ->RunBackend(std::move(compiled_module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .status()); @@ -699,9 +699,9 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module().entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), ShapeUtil::MakeTupleShape({ @@ -713,19 +713,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { param_shape)); computation_layout.mutable_result_layout()->ResetLayout( LayoutUtil::MakeLayout({2, 1, 0})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(&module(), &computation_layout); - EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(0) .layout() .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(1) .layout() @@ -785,7 +785,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); const HloInstruction* true_root = true_computation->root_instruction(); const HloInstruction* false_root = false_computation->root_instruction(); @@ -812,7 +812,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(module.get()).status(); + Status error_status = layout_assignment.Run(module).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( error_status.error_message(), @@ -839,9 +839,9 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module().entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -851,14 +851,13 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(module.get(), &computation_layout, &channel_constraints); + AssignLayouts(&module(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0)); - EXPECT_TRUE( - ShapeUtil::Equal(ShapeUtil::GetSubshape( - FindInstruction(module.get(), "send")->shape(), {0}), - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); + EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0)); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { @@ -873,11 +872,11 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -901,11 +900,11 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -932,11 +931,11 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -963,11 +962,11 @@ TEST_F(LayoutAssignmentTest, } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -985,11 +984,11 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index f0e2566a3f9ef5c0be8af46d3a16cd9c72793366..b27a92f2a0761a2bccd97eb2c0467ead27565c37 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -68,9 +68,9 @@ Status RecordArguments(const absl::Span arguments, module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, *argument)); - *module->add_arguments() = literal->ToProto(); + *module->add_arguments() = literal.ToProto(); } return Status::OK(); } @@ -80,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, result)); - *module->mutable_result() = literal->ToProto(); + *module->mutable_result() = literal.ToProto(); return Status::OK(); } @@ -812,7 +812,7 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN(std::unique_ptr module, HloModule::CreateFromProto(module_proto, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, @@ -928,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, shaped_buffer->device_ordinal())); TF_ASSIGN_OR_RETURN( - std::unique_ptr result_literal, + Literal result_literal, execute_backend_->transfer_manager()->TransferLiteralFromDevice( stream.get(), *shaped_buffer)); - if (LayoutUtil::LayoutsInShapesEqual(*return_shape, - result_literal->shape())) { - *result->mutable_literal() = result_literal->ToProto(); + if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) { + *result->mutable_literal() = result_literal.ToProto(); } else { *result->mutable_literal() = - result_literal->Relayout(*return_shape)->ToProto(); + result_literal.Relayout(*return_shape).ToProto(); } return Status::OK(); } @@ -959,9 +958,9 @@ std::unique_ptr CloneShapedBufferOnDevice( Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); std::vector replicas; if (arg->has_device_handle()) { @@ -983,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - stream.get(), *literal, shaped_buffer)); + stream.get(), literal, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } TF_ASSIGN_OR_RETURN(*result->mutable_data(), @@ -1018,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor = replicas[arg->replica_id()]; } - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - return execute_backend_->transfer_manager()->TransferLiteralToInfeed( - executor, *literal); + return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor, + literal); } Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, @@ -1049,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), *literal)); - *result->mutable_literal() = literal->ToProto(); + executor, arg->shape_with_layout(), literal)); + *result->mutable_literal() = literal.ToProto(); return Status::OK(); } @@ -1085,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, HloModule::CreateFromProto(arg->computation(), config)); HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, - evaluator.Evaluate>( - *module, /*arg_literals=*/{})); + TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate( + *module, /*arg_literals=*/{})); // Since the result layout is non-effective to the Evaluator results, explicit // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); + result_literal = result_literal.Relayout(arg->output_layout()); } - *result->mutable_literal() = result_literal->ToProto(); + *result->mutable_literal() = result_literal.ToProto(); return Status::OK(); } @@ -1162,7 +1160,7 @@ StatusOr> Service::Replicas( return replicas; } -Status Service::MaybeDumpHloModule(const HloModule& module) const { +Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const { const string xla_dump_unoptimized_hlo_proto_to = module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); if (xla_dump_unoptimized_hlo_proto_to.empty()) { @@ -1170,7 +1168,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const { } HloProto proto = MakeHloProto(module); return protobuf_util::DumpProtoToDirectory( - proto, xla_dump_unoptimized_hlo_proto_to, module.name()); + proto, xla_dump_unoptimized_hlo_proto_to, + StrCat(module.name(), ".unoptimized")); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 44c5248b150cff57546d3287869787f37c8975ba..1f62fad4c8079eba7013b3f647fe19bbc031fc77 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -271,7 +271,9 @@ class Service : public ServiceInterface { StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; - Status MaybeDumpHloModule(const HloModule& module) const; + // Dumps the (unoptimized) module given if the corresponding DebugOptions + // field has been set. + Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const; // Returns the device handle that represents the replicated device for a // single computation that is not model-parallelized. diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc deleted file mode 100644 index dd53c7531bea4273b5f8dc1c993e7720eb1afeb2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/source_map_util.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/source_map_util.h" - -#include "absl/strings/str_format.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { -namespace source_map_util { -namespace { - -Status InvalidParameterArgumentV(const OpMetadata& op_metadata, - const char* format, va_list args) { - string message; - tensorflow::strings::Appendv(&message, format, args); - if (!op_metadata.source_file().empty()) { - absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(), - op_metadata.source_line()); - } - return InvalidArgument("%s", message); -} - -} // namespace - -Status InvalidParameterArgument(const OpMetadata& op_metadata, - const char* format, ...) { - va_list args; - va_start(args, format); - Status result = InvalidParameterArgumentV(op_metadata, format, args); - va_end(args); - return result; -} - -Status InvalidParameterArgument(Executable* executable, int parameter_number, - const char* format, ...) { - va_list args; - va_start(args, format); - if (executable != nullptr && executable->has_module()) { - const HloModule& module = executable->module(); - const HloComputation& computation = *module.entry_computation(); - HloInstruction* param = computation.parameter_instruction(parameter_number); - const OpMetadata& metadata = param->metadata(); - Status result = InvalidParameterArgumentV(metadata, format, args); - va_end(args); - return result; - } - Status result = InvalidArgumentV(format, args); - va_end(args); - return result; -} - -} // namespace source_map_util -} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index b8d2d546e5d4dc67e3f314dfc6dcd4e8df5451c5..a21e586efadb85d18e88e44999283b28f7f65eac 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() { return r; } -StatusOr> TransferManager::TransferLiteralFromDevice( +StatusOr TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer) { - StatusOr> ret; + StatusOr ret; se::Stream* substream = stream->GetOrCreateSubStream(); substream->ThenWaitFor(stream); @@ -63,7 +63,7 @@ StatusOr> TransferManager::TransferLiteralFromDevice( if (!s.ok()) { return s; } - return absl::make_unique(std::move(literal)); + return std::move(literal); } Status TransferManager::TransferLiteralFromDevice( @@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice( return substream->BlockHostUntilDone(); } -StatusOr> TransferManager::TransferArrayFromDevice( +StatusOr TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source) { - StatusOr> ret; + StatusOr ret; // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. @@ -122,7 +122,7 @@ StatusOr> TransferManager::TransferArrayFromDevice( if (!s.ok()) { return s; } - return absl::make_unique(std::move(literal)); + return std::move(literal); } Status TransferManager::TransferArrayToDevice( diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 21725946b3629a4495d8ad6cc1529d712d22e0af..f952e64af2b675b9c0f8a30e9a2bc3c855e34efa 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -57,7 +57,7 @@ class TransferManager { // without waiting for any other operation on a stream to complete. // // This function should be avoided in favor of the asynchronous version below. - virtual StatusOr> TransferLiteralFromDevice( + virtual StatusOr TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer); virtual Status TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, @@ -113,9 +113,9 @@ class TransferManager { Status TransferArrayToDeviceAsync(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - StatusOr> TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source); + StatusOr TransferArrayFromDevice(se::Stream* stream, + const Shape& shape, + const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, // using the given executor. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 2b2a2eb42a477c6c3896ef4e267a04f59d30f0bd..e9a07b14ed685fa4388aca583395370a60176cca 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -555,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()}))); + Literal elements[] = {LiteralUtil::CreateR2({{1.0}, {2.0}}), + LiteralUtil::CreateR1({2.0, 42})}; + auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index 39b693872da6bd985d95c2abc9519662c838a3f5..516754e2110ee50a597818c4a8bcfbfbb76c5cec 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class TupleSimplifierTest : public HloTestBase { +class TupleSimplifierTest : public HloVerifiedTestBase { protected: void Run(HloModule* module, bool change_expected) { TupleSimplifier simplifier; @@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { @@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTuple) { @@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) { EXPECT_THAT(computation->root_instruction(), gte); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param1); } @@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) { EXPECT_THAT(computation->root_instruction(), op::Negate(op::GetTupleElement(op::Tuple()))); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); } @@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) { EXPECT_THAT(computation->root_instruction(), element); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param); } @@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), tuple_param); } @@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); EXPECT_THAT(computation->root_instruction(), tuple); } @@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { entry = module->AddEntryComputation(builder.Build()); } - Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); + Run(module, /*change_expected=*/true, /*exclude_entry=*/true); EXPECT_THAT(c0->root_instruction(), p0); EXPECT_THAT(c1->root_instruction(), p1); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index c3c2603c7eb58d3e57346d2ea1e0058f8e5d7fe8..541b117e0299c94de330604ec5c16e20f07c425f 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -183,8 +183,7 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, HloEvaluator evaluator(/*max_loop_iterations=*/0); auto* while_init = while_op->mutable_operand(0); auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init); + StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); if (!indvar_init_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable init: " << indvar_init_result.status(); @@ -197,31 +196,27 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); // The initial value of the induction variable. - std::unique_ptr indvar_iter_val = - std::move(indvar_init_result).ValueOrDie(); + Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie(); for (int64 trip_count = 0; trip_count != max_value_returned + 1; ++trip_count) { auto* while_cond = while_op->while_condition(); auto* while_cond_root = while_cond->root_instruction(); auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr> result = - evaluator.EvaluateWithSubstitutions( - while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}}); + StatusOr result = evaluator.EvaluateWithSubstitutions( + while_cond_root, {{while_cond_indvar, &indvar_iter_val}}); if (!result.ok()) { VLOG(2) << "Couldn't evaluate while cond: " << result.status(); return nullopt; } - if (result.ValueOrDie()->data() == absl::Span{false}) { + if (result.ValueOrDie().data() == absl::Span{false}) { VLOG(2) << "Loop has static trip count of " << trip_count; return trip_count; } // Calculate the value of the induction variable after one iteration of the // loop, and check whether the while condition is true with this new value. - StatusOr> indvar_next_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, - {{while_body_indvar, indvar_iter_val.get()}}); + StatusOr indvar_next_result = evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}}); if (!indvar_next_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable update: " << indvar_next_result.status(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 52c895e8d4b2aa55b55df41b7139b00c576d6e99..df610102b4c7fa08c0b7030124939009130f89f4 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -224,14 +224,13 @@ class ShapeTree { // REQUIRES: index must exist in the ShapeTree. iterator find(ShapeIndexView index) { Node* element = Lookup(index); - return iterator(&nodes_, typename std::vector::iterator(element), - /*iterate_leaves_only=*/false); + auto element_iter = nodes_.begin() + (element - &nodes_[0]); + return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } const_iterator find(ShapeIndexView index) const { Node* element = Lookup(index); - return iterator(&nodes_, - typename std::vector::const_iterator(element), - /*iterate_leaves_only=*/false); + auto element_iter = nodes_.cbegin() + (element - &nodes_[0]); + return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } // Returns the number of leaf nodes in the tree. diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 9772c06bce32cef0d79a036b525c3606ea60e31b..96c80fd577e2601c972e374a153f4f0706902ec2 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -441,6 +441,19 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return count; } +/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape, + PrimitiveType primitive_type) { + if (shape.element_type() == primitive_type) { + return true; + } + for (const Shape& element_shape : shape.tuple_shapes()) { + if (HasPrimitiveType(element_shape, primitive_type)) { + return true; + } + } + return false; +} + /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 8234fcdd3f57978b94630d4e2880826dd678389f..623ae39de819ebecdc8aee27a2b31176421ef020 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -180,6 +180,10 @@ class ShapeUtil { // As ElementsIn(), but recurses through tuples. static int64 ElementsInRecursive(const Shape& shape); + // Returns true if shape has the primitive type, recurses through tuples. + static bool HasPrimitiveType(const Shape& shape, + PrimitiveType primitive_type); + // Returns true if 'shape' is an array with zero elements. static bool IsZeroElementArray(const Shape& shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 6ca4085aaf3bd1c181da3b94aa6c570e21172d0a..c622ecdca1fd66604d1a6ceaf705f2e70edaee55 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) { EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17}))); } +TEST(ShapeUtilTest, HasPrimitiveType) { + EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32)); + EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32)); + EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}), + S32)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}), + S16)); +} + TEST(ShapeUtilTest, IsZeroElementArray) { EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {}))); EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0}))); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index d0bda45cf8e1a8ea6530f9996b7fef0834a1b0dc..30e3077edb93e1ac740c1d863aacce975ad4c8a5 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -647,6 +647,7 @@ xla_test( ], shard_count = 48, tags = [ + "broken", "manual", "notap", ], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 0bf4556b437fb1717a9c9773834fa3031cfbd6ea..c257566fb218d4769aec0c793efb9256b023b7ea 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -41,7 +41,6 @@ limitations under the License. namespace xla { namespace { - class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: ErrorSpec error_spec_{0.0001, 0.0001}; @@ -227,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0x8000000000000000LL, 0x8000000000000000LL, 1}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::unique_ptr lhs_data = - client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); std::vector rhs{1, 0x7FFFFFFFFFFFFFFLL, @@ -241,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0, 1, 0x8000000000000000LL}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); std::unique_ptr rhs_data = - client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); Add(lhs_param, rhs_param); @@ -267,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 1, 0, -1}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::unique_ptr lhs_data = - client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); std::vector rhs{-1, 0, @@ -280,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0x7FFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); std::unique_ptr rhs_data = - client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); Sub(lhs_param, rhs_param); @@ -299,16 +298,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) { XlaBuilder b(TestName()); std::vector lhs{static_cast(0x8000000000000000ULL)}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::vector rhs{static_cast(0x7FFFFFFFFFFFFFFFULL)}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); Lt(lhs_param, rhs_param); - ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)}); + ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)}); } TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { @@ -321,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { b_values.push_back(2 * i / static_cast(count + 2)); } - std::unique_ptr a_literal = LiteralUtil::CreateR1({a_values}); + Literal a_literal = LiteralUtil::CreateR1({a_values}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); auto a_constant = ConstantR1(&builder, a_values); - auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param"); + auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param"); - std::unique_ptr b_literal = LiteralUtil::CreateR1({b_values}); + Literal b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param"); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); + auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param"); auto b_param = ConstantR1(&builder, b_values); auto sum1 = Add(a_constant, b_constant); @@ -1422,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr param_literal = LiteralUtil::CreateR1(values); + Literal param_literal = LiteralUtil::CreateR1(values); std::unique_ptr param_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto sum = ConstantR0(&b, 0.0f); - auto param = Parameter(&b, 0, param_literal->shape(), "param"); + auto param = Parameter(&b, 0, param_literal.shape(), "param"); for (float exponent : exponents) { sum = Add(sum, Pow(param, ConstantR0(&b, exponent))); } @@ -1450,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Pow(Exp(param0), param1); std::vector expected(values0.size()); @@ -1475,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Log(Pow(param0, param1)); std::vector expected(values0.size()); @@ -1500,14 +1499,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Mul(Exp(param0), Exp(param1)); std::vector expected(values0.size()); @@ -1525,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Div(param0, Exp(param1)); std::vector expected(values0.size()); @@ -1551,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + client_->TransferToServer(literal2).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(Div(param0, param1), param2); std::vector expected(values0.size()); @@ -1583,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Div(param1, param2)); std::vector expected(values0.size()); @@ -1616,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Pow(param1, param2)); std::vector expected(values0.size()); @@ -1650,26 +1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; std::vector values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - std::unique_ptr literal3 = LiteralUtil::CreateR1(values3); + Literal literal3 = LiteralUtil::CreateR1(values3); std::unique_ptr data3 = - client_->TransferToServer(*literal3).ConsumeValueOrDie(); + client_->TransferToServer(literal3).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); - auto param3 = Parameter(&b, 3, literal3->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); + auto param3 = Parameter(&b, 3, literal3.shape(), "param2"); Div(Div(param0, param1), Div(param2, param3)); std::vector expected(values0.size()); @@ -2096,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, @@ -2118,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); Array3D expected(0, 7, 0); @@ -2140,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); - auto p = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p = Parameter(&builder, 0, param0_literal.shape(), "param0"); Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, @@ -2206,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { 0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05}); TF_ASSERT_OK_AND_ASSIGN(auto input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Tanh(input); ComputeAndCompareR1( @@ -2239,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. - std::unique_ptr input_literal = LiteralUtil::CreateR1( + Literal input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, @@ -2252,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { 78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3, 86.4, 86.5, 87.6, 87.7, 87.8, 87.9}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Exp(input); std::vector expected_result; - int64 input_size = input_literal->shape().dimensions(0); + int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(std::exp(input_literal->Get({i}))); + expected_result.push_back(std::exp(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, @@ -2273,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // implementation on XLA CPU. XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR1( + Literal input_literal = LiteralUtil::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, @@ -2290,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { 1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33, 1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Log(input); std::vector expected_result; - int64 input_size = input_literal->shape().dimensions(0); + int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(std::log(input_literal->Get({i}))); + expected_result.push_back(std::log(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, @@ -2465,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0}); Tuple(&builder, {cmp_dim_0, cmp_dim_1}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), - LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{true, true}, {true, false}}), + LiteralUtil::CreateR2({{true, false}, {false, false}})}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { @@ -2821,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::iota(r1.begin(), r1.end(), 1.0); XlaBuilder builder(TestName()); - std::unique_ptr a_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); - auto a = ConstantLiteral(&builder, *a_literal); + Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); + auto a = ConstantLiteral(&builder, a_literal); auto b = ConstantR1(&builder, r1); Add(a, b, {1}); @@ -2886,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { XlaBuilder builder(TestName()); auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); auto y_literal = LiteralUtil::CreateR1({4, 5}); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); - auto y = Parameter(&builder, 1, y_literal->shape(), "y"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); + auto y = Parameter(&builder, 1, y_literal.shape(), "y"); auto slice = Slice(x, {1}, {2}, {1}); Sub(slice, y); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index ac90a3adb6dbad30e3ef0b11438fb9a6fd6f8574..bc2ba151a38f1ab000b342dcd4bdd8f53d9ce9a9 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -63,7 +63,7 @@ class BatchNormalizationTest {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_)); + input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, - {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) - .get(), - LiteralUtil::CreateR1({4, 5}).get(), - LiteralUtil::CreateR1({5, 5}).get()}); + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}), + LiteralUtil::CreateR1({4, 5}), + LiteralUtil::CreateR1({5, 5})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { @@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, - {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) - .get(), - LiteralUtil::CreateR1({4, 5}).get(), - LiteralUtil::CreateR1({5, 5}).get()}); + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}), + LiteralUtil::CreateR1({4, 5}), + LiteralUtil::CreateR1({5, 5})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { @@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { BatchNormTraining(h0, h1, h2, /*epsilon=*/1, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) - .get(), - LiteralUtil::CreateR1(std::vector(260, 1.0f)).get(), - LiteralUtil::CreateR1(std::vector(260, 0.0f)).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)), + LiteralUtil::CreateR1(std::vector(260, 1.0f)), + LiteralUtil::CreateR1(std::vector(260, 0.0f))}); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {operand.get(), scale.get(), offset.get()}, ErrorSpec(0.1)); } @@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { BatchNormTraining(h0, h1, h2, /*epsilon=*/-100, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR3FromArray3D( - {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) - .get(), - LiteralUtil::CreateR1(std::vector(1, 15.0f)).get(), - LiteralUtil::CreateR1(std::vector(1, 125.0f)).get()}); + {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}), + LiteralUtil::CreateR1(std::vector(1, 15.0f)), + LiteralUtil::CreateR1(std::vector(1, 125.0f))}); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {operand.get(), scale.get(), offset.get()}, ErrorSpec(0.1)); } @@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, - {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) - .get(), - LiteralUtil::CreateR1({0, 0}).get(), - LiteralUtil::CreateR1({16, 20}).get()}); + {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}), + LiteralUtil::CreateR1({0, 0}), + LiteralUtil::CreateR1({16, 20})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } struct BatchNormTestParam { @@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); auto scale_activations = - Parameter(&builder, 1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal.shape(), "offset"); auto offset_activations = - Parameter(&builder, 2, offset_literal->shape(), "scale"); + Parameter(&builder, 2, offset_literal.shape(), "scale"); - auto expected = LiteralUtil::MakeTuple( - {expected_normalized.get(), LiteralUtil::CreateR1(mean).get(), - LiteralUtil::CreateR1(var).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {expected_normalized, LiteralUtil::CreateR1(mean), + LiteralUtil::CreateR1(var)}); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr offset_data = - client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + client_->TransferToServer(offset_literal).ConsumeValueOrDie(); BatchNormTraining(input_activations, scale_activations, offset_activations, epsilon, feature_index); @@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { // testcase. execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); ComputeAndCompareTuple( - &builder, *expected, + &builder, expected, {input_data.get(), scale_data.get(), offset_data.get()}, ErrorSpec(0.01, 1)); } @@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); auto scale_activations = - Parameter(&builder, 1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal.shape(), "offset"); auto offset_activations = - Parameter(&builder, 2, offset_literal->shape(), "scale"); - auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean"); + Parameter(&builder, 2, offset_literal.shape(), "scale"); + auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean"); auto variance_activations = - Parameter(&builder, 4, var_literal->shape(), "variance"); + Parameter(&builder, 4, var_literal.shape(), "variance"); Array4D expected = normalized; std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr offset_data = - client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + client_->TransferToServer(offset_literal).ConsumeValueOrDie(); std::unique_ptr mean_data = - client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + client_->TransferToServer(mean_literal).ConsumeValueOrDie(); std::unique_ptr variance_data = - client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + client_->TransferToServer(var_literal).ConsumeValueOrDie(); BatchNormInference(input_activations, scale_activations, offset_activations, mean_activations, variance_activations, epsilon, @@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { auto grad_output_literal = LiteralUtil::CreateR4FromArray4D(grad_output_array); - auto input_parameter = - Parameter(&builder, 0, input_literal->shape(), "input"); - auto scale_parameter = - Parameter(&builder, 1, scale_literal->shape(), "scale"); - auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean"); - auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance"); + auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input"); + auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale"); + auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean"); + auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance"); auto grad_output_parameter = - Parameter(&builder, 4, grad_output_literal->shape(), "grad_output"); + Parameter(&builder, 4, grad_output_literal.shape(), "grad_output"); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr mean_data = - client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + client_->TransferToServer(mean_literal).ConsumeValueOrDie(); std::unique_ptr var_data = - client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + client_->TransferToServer(var_literal).ConsumeValueOrDie(); std::unique_ptr grad_output_data = - client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); + client_->TransferToServer(grad_output_literal).ConsumeValueOrDie(); BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter, grad_output_parameter, epsilon, feature_index); - auto expected = - LiteralUtil::MakeTuple({expected_grad_activation.get(), - LiteralUtil::CreateR1(grad_scale).get(), - LiteralUtil::CreateR1(grad_offset).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {expected_grad_activation, LiteralUtil::CreateR1(grad_scale), + LiteralUtil::CreateR1(grad_offset)}); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor // testcase. execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {input_data.get(), scale_data.get(), mean_data.get(), var_data.get(), grad_output_data.get()}, ErrorSpec(0.01, 1)); diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 65589b0d6af2ffca26776541eb05a093f43e0a9a..e9728e636f0ee032416b2da17a3ea83c5bb18083 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4( {{{{static_cast(-1.6875f)}, {static_cast(-2.04f)}}, {{static_cast(0.105f)}, {static_cast(0.66f)}}}, {{{static_cast(1.89f)}, {static_cast(3.35f)}}, - {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) - .get(), + {{static_cast(3.7f)}, {static_cast(6.04f)}}}}), LiteralUtil::CreateR1( - {static_cast(4), static_cast(5)}) - .get(), + {static_cast(4), static_cast(5)}), LiteralUtil::CreateR1( - {static_cast(5), static_cast(5)}) - .get()}); + {static_cast(5), static_cast(5)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); } XLA_TEST_F(Bfloat16Test, BatchNormGrad) { @@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4( {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, {{static_cast(-1.f)}, {static_cast(-1.f)}}}, {{{static_cast(1.f)}, {static_cast(1.f)}}, - {{static_cast(3.f)}, {static_cast(3.f)}}}}) - .get(), + {{static_cast(3.f)}, {static_cast(3.f)}}}}), LiteralUtil::CreateR1( - {static_cast(0), static_cast(0)}) - .get(), + {static_cast(0), static_cast(0)}), LiteralUtil::CreateR1( - {static_cast(16), static_cast(20)}) - .get()}); + {static_cast(16), static_cast(20)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index fe4267c73bd170f22a0456533f45e50be823a80b..dde19fb65d65064c9452a6ac49c70e20cf113336 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -60,10 +60,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout( + auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r3_global_data = - client_->TransferToServer(*r3_data).ConsumeValueOrDie(); + client_->TransferToServer(r3_data).ConsumeValueOrDie(); return r3_global_data; } @@ -74,10 +74,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); - auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout( + auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r2_global_data = - client_->TransferToServer(*r2_data).ConsumeValueOrDie(); + client_->TransferToServer(r2_data).ConsumeValueOrDie(); return r2_global_data; } @@ -293,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}}), - ConstantLiteral(&b, *LiteralUtil::CreateR3( + ConstantLiteral(&b, LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); @@ -301,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } struct R3ImplicitBroadcastSpec { @@ -370,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( - &builder, *expected, - {r3_implicit_global_data.get(), r3_global_data.get()}, + &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()}, ErrorSpec(1e-7, 1e-7)); } @@ -395,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { auto expected = LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, + ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}, {2}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}, {2}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + &b, LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } struct R2ImplicitBroadcastSpec { @@ -618,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( - &builder, *expected, + &builder, expected, {r2_implicit_global_data1.get(), r2_global_data.get(), r2_implicit_global_data2.get()}, ErrorSpec(1e-6, 1e-6)); @@ -630,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}})); - auto r2 = - ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}})); + auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1}, {2}})); - auto r2 = - ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1}, {2}})); + auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1, {0}); auto expected = LiteralUtil::CreateR3( {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {1}); auto expected = LiteralUtil::CreateR3( {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {2}); auto expected = LiteralUtil::CreateR3( {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { @@ -697,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { auto r1_1 = ConstantR1(&b, {100, 200}); auto r1_2 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { r3 = Add(r1_0, r3, {0}); r3 = Add(r3, r1_1, {1}); @@ -709,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { @@ -730,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { @@ -739,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}, {1.0, 5.0}}), - ConstantLiteral(&b, *LiteralUtil::CreateR3( + ConstantLiteral(&b, LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 74d4d2eb10c32b270a83aa04dd2e6025d7a56c26..9966e4606ef7f104487182e0240e64e4c9e4d834 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0(42.0), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0(42.0), result, + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { @@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), result, error_spec_)); } @@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralSlice(*result, {0}), error_spec_)); + LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + LiteralSlice(result, {0}), error_spec_)); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralSlice(*result, {1}), error_spec_)); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + LiteralSlice(result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { @@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), result, error_spec_)); } @@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), result, error_spec_)); } @@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), - *result, error_spec_)); + LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { @@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array), + result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { @@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { @@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index b1d18210eaafdfec0920c0cccaa0dfdbd6de5609..8b31e53707eee456e09adfe9fb76f03a8855056d 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32IdentityComputation(); - auto constant = - ConstantLiteral(&builder, *LiteralUtil::CreateR0(42.0)); + auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0(42.0)); Call(&builder, callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); @@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S0F32AdditionComputation(); - auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); - auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); + auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1({})); + auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1({})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); @@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S2F32AdditionComputation(); auto x = - ConstantLiteral(&builder, *LiteralUtil::CreateR1({1.0f, 2.0f})); + ConstantLiteral(&builder, LiteralUtil::CreateR1({1.0f, 2.0f})); auto y = - ConstantLiteral(&builder, *LiteralUtil::CreateR1({2.0f, 3.0f})); + ConstantLiteral(&builder, LiteralUtil::CreateR1({2.0f, 3.0f})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); @@ -133,7 +132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr start, - client_->TransferToServer(*LiteralUtil::CreateR0(1.0f))); + client_->TransferToServer(LiteralUtil::CreateR0(1.0f))); ComputeAndCompareR0(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f)); } @@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32TupleComputation(); auto elem = LiteralUtil::CreateR0(42.0); - auto tuple = LiteralUtil::MakeTuple({elem.get()}); - Call(&builder, callee, {ConstantLiteral(&builder, *elem)}); + auto tuple = LiteralUtil::MakeTuple({&elem}); + Call(&builder, callee, {ConstantLiteral(&builder, elem)}); - ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); + ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index a4eb57fc7b9abd460a7d158d0dc629eba88018cd..2f1510ff6969757f8091e9c043b61cb2a467ccd5 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); - auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1"); Add(p0, p1); auto param0_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto param1_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); @@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { auto computation = computation_status.ConsumeValueOrDie(); auto f32_literal = LiteralUtil::CreateR0(1.1f); - auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); + auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie(); auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); auto f32_4_data = - client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); + client_->TransferToServer(f32_4_literal).ConsumeValueOrDie(); auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); - auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); + auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie(); // Match auto status = client_->Execute( diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 8a236db0ff2f63332892de822461dd1cc17276ca..fbdf0fcb6543f09dedefef55cfe0f8a5d9067d5a 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -101,7 +101,7 @@ StatusOr> ClientLibraryTestBase::Execute( return client_->Execute(computation, arguments, &execution_options_); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( +StatusOr ClientLibraryTestBase::ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; @@ -113,7 +113,7 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( +StatusOr ClientLibraryTestBase::ExecuteAndTransfer( XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout) { // Build the computation, as a convenience. @@ -121,8 +121,7 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } -StatusOr> -ClientLibraryTestBase::ExecuteAndTransferReference( +StatusOr ClientLibraryTestBase::ExecuteAndTransferReference( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; @@ -148,15 +147,15 @@ string ClientLibraryTestBase::ExecuteToString( if (!result.ok()) { return result.status().ToString(); } else { - return result.ValueOrDie()->ToString(); + return result.ValueOrDie().ToString(); } } void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, absl::Span arguments) { - std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -182,7 +181,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( const string& error_message)>& verify_output) { // Try with no layout requirement. TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments)); - verify_output(*actual, ""); + verify_output(actual, ""); // Try with all output layouts. std::vector minor_to_major(ShapeUtil::Rank(expected.shape())); @@ -193,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( AsInt64Slice(expected.shape().dimensions()), minor_to_major); TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, &layout)); - verify_output(*actual, + verify_output(actual, absl::StrCat("Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); @@ -218,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_ASSIGN_OR_RETURN(auto literal, client_->Transfer(*arguments[index], nullptr)); // Skip tuples because they don't have a rank. - if (ShapeUtil::IsTuple(literal->shape())) { + if (ShapeUtil::IsTuple(literal.shape())) { layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal->shape())); + ShapeUtil::HumanStringWithLayout(literal.shape())); arguments_with_layout.push_back(arguments[index]); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -228,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return Status::OK(); } - std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); + std::vector minor_to_major(ShapeUtil::Rank(literal.shape())); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto literal_relayout = - literal->Relayout(LayoutUtil::MakeLayout(minor_to_major)); + literal.Relayout(LayoutUtil::MakeLayout(minor_to_major)); layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal_relayout->shape())); + ShapeUtil::HumanStringWithLayout(literal_relayout.shape())); TF_ASSIGN_OR_RETURN(auto data, - client_->TransferToServer(*literal_relayout)); + client_->TransferToServer(literal_relayout)); arguments_with_layout.push_back(data.get()); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -256,7 +255,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( for (const auto& str : layout_strings) { absl::StrAppend(&error_message, str, " "); } - verify_output(*actual, error_message); + verify_output(actual, error_message); return Status::OK(); }; @@ -290,11 +289,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. const Literal* expected_ptr = &expected; - std::unique_ptr converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -319,7 +318,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)); return Status::OK(); } @@ -346,11 +345,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. const Literal* expected_ptr = &expected; - std::unique_ptr converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -376,7 +375,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)); return Status::OK(); } @@ -391,12 +390,12 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. - std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); + Literal expected_literal = LiteralUtil::CreateR1U8(expected); - VLOG(1) << "expected: " << expected_literal->ToString(); - VLOG(1) << "actual: " << actual->ToString(); + VLOG(1) << "expected: " << expected_literal.ToString(); + VLOG(1) << "actual: " << actual.ToString(); - EXPECT_EQ(expected, actual->GetR1U8AsString()); + EXPECT_EQ(expected, actual.GetR1U8AsString()); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -408,7 +407,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -420,7 +419,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -430,9 +429,9 @@ void ClientLibraryTestBase::ComputeAndCompare( if (!status_or_data.ok()) { return; } - std::unique_ptr reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(reference, result)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -442,12 +441,12 @@ void ClientLibraryTestBase::ComputeAndCompare( if (!status_or_data.ok()) { return; } - std::unique_ptr reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error)); } -StatusOr, std::unique_ptr>> +StatusOr> ClientLibraryTestBase::ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments) { // Transfer the arguments to the executor service. We put the unique_ptr's @@ -569,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { return ConstantLiteral(builder, use_bfloat16_ - ? *LiteralUtil::ConvertF32ToBF16(literal) - : literal); + ? LiteralUtil::ConvertF32ToBF16(literal) + : LiteralSlice(literal)); } std::unique_ptr @@ -600,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( const Literal& literal) { if (use_bfloat16_) { - return std::move(*LiteralUtil::ConvertF32ToBF16(literal)); + return LiteralUtil::ConvertF32ToBF16(literal); } return literal.Clone(); } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 22dfdfb0e4c67cc06fa748177c75cf35572196c8..9d32f4f5174a57a53a9d3e6477b46fa4de852f7f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -95,11 +95,11 @@ class ClientLibraryTestBase : public ::testing::Test { StatusOr> Execute( XlaBuilder* builder, absl::Span arguments); - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); @@ -107,7 +107,7 @@ class ClientLibraryTestBase : public ::testing::Test { // This executes the computation via the reference client (which connects a // interpreter backend). The result is used as the expected values of the // computation. - StatusOr> ExecuteAndTransferReference( + StatusOr ExecuteAndTransferReference( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); @@ -282,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp AddParam(const Array& argument, XlaBuilder* builder) { - return AddParam(*LiteralUtil::CreateFromArray(argument), builder); + return AddParam(LiteralUtil::CreateFromArray(argument), builder); } // Creates a constant instruction with the given literal. When the @@ -297,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp CreateConstantFromArray(const Array& array, XlaBuilder* builder) { - return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array), + return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array), builder); } // Same as CreateConstantFromArray, but for scalars. template XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { - return CreateConstantFromLiteral(*LiteralUtil::CreateR0(value), + return CreateConstantFromLiteral(LiteralUtil::CreateR0(value), builder); } @@ -375,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test { // Executes the computation and calculates the expected reference value using // the reference client. Returns two literals in the order of (expected, // actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(XlaBuilder* builder, - absl::Span arguments); + StatusOr> ComputeValueAndReference( + XlaBuilder* builder, absl::Span arguments); Client* client_; Client* ref_client_; // To compute reference result. @@ -412,9 +411,8 @@ template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, absl::Span arguments) { - std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -428,9 +426,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -438,9 +435,8 @@ template void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, absl::Span expected, absl::Span arguments) { - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -454,9 +450,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -464,9 +459,9 @@ template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, absl::Span arguments) { - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -480,9 +475,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -490,9 +485,9 @@ template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, absl::Span arguments) { - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -506,9 +501,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -516,9 +511,9 @@ template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, absl::Span arguments) { - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -532,9 +527,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -542,13 +537,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR0(value); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR0(value); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -556,13 +551,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( absl::Span values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR1(values); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR1(values); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -570,13 +565,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -584,13 +579,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(array_3d); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index c898dacf489db97223e2918414daf5de88bece64..6f2ca84bb646e88af221ab80b727911ff7d990eb 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { std::unique_ptr data, client_->Execute(computation, {}, &execution_options)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR2WithLayout( - {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); + Literal expected_literal = LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); TF_ASSERT_OK_AND_ASSIGN( - auto computed, client_->Transfer(*data, &expected_literal->shape())); + auto computed, client_->Transfer(*data, &expected_literal.shape())); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( - expected_literal->shape(), computed->shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + expected_literal.shape(), computed.shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } } @@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - LiteralSlice(*result, {0})); + LiteralSlice(result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - LiteralSlice(*result, {1})); + LiteralSlice(result, {1})); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape())); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 0), + ShapeUtil::GetTupleElementShape(result.shape(), 0), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{0, 1}))); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 1), + ShapeUtil::GetTupleElementShape(result.shape(), 1), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{1, 0}))); } @@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr const_arg, client_->TransferToServer( - *LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); + LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); XlaBuilder b(TestName() + ".add"); Add(Parameter(&b, 0, shape, "param_0"), @@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN( auto result_literal, - client_->Transfer(*results[0], &expected_result->shape())); + client_->Transfer(*results[0], &expected_result.shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 03d56964998f9abea21d6f82dee8faf86f9fe1d4..6ef7ca035f75966bef12c7abcb55cb59e9b73655 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -42,14 +42,14 @@ class CompilationCacheTest : public ClientLibraryTestBase { absl::Span arguments, float expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; - std::unique_ptr result = + Literal result = client_ ->ExecuteAndTransfer(computation, arguments, /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR0(expected_result), *result, error_spec_)); + LiteralUtil::CreateR0(expected_result), result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -63,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase { ->Execute(computation, arguments, &execution_options_, &execution_profile) .ConsumeValueOrDie(); - std::unique_ptr result = - client_->Transfer(*data_handle).ConsumeValueOrDie(); + Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2(expected_result), *result, error_spec_)); + LiteralUtil::CreateR2(expected_result), result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -88,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = - client_->TransferToServer(*LiteralUtil::CreateR0(42.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(42.0f)) .ConsumeValueOrDie(); std::unique_ptr data_123 = - client_->TransferToServer(*LiteralUtil::CreateR0(123.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(123.0f)) .ConsumeValueOrDie(); std::unique_ptr data_456 = - client_->TransferToServer(*LiteralUtil::CreateR0(456.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(456.0f)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); @@ -145,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { auto rowmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); auto rowmaj_handle = - client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); + client_->TransferToServer(rowmaj_array).ConsumeValueOrDie(); auto colmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); auto colmaj_handle = - client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); + client_->TransferToServer(colmaj_array).ConsumeValueOrDie(); XlaBuilder builder(TestName()); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 8226b6de3f780197bc0f1145b617dba99803927f..3b0414a6045a7c5f4f75948d8ccf2775c575626e 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test { LOG(FATAL) << "invalid client_type value"; } - StatusOr> ComputeConstantLiteral( - Client* client, const XlaOp& operand, XlaBuilder* builder, - Layout* output_layout = nullptr) { + StatusOr ComputeConstantLiteral(Client* client, const XlaOp& operand, + XlaBuilder* builder, + Layout* output_layout = nullptr) { TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand)); TF_ASSIGN_OR_RETURN(auto computed, client->ComputeConstant(subgraph, output_layout)); @@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test { XlaBuilder* builder) { TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand, builder, nullptr)); - return literal->Get({}); + return literal.Get({}); } bool IsConstant(const XlaOp& operand, XlaBuilder* builder) { @@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1({4, 6}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + Literal expected_literal = LiteralUtil::CreateR1({4, 6}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } @@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + Literal expected_literal = LiteralUtil::CreateR0(5); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } @@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) { ConstantR2(&b, {{10, 20}, {30, 40}})), &b, &layout_proto)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR2WithLayout( - {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); + Literal expected_literal = LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( - expected_literal->shape(), computed->shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + expected_literal.shape(), computed.shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } } diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index be017477d84eb9faf5aa79dcdf54d6b6aaf6fd8e..9811a015e91d866d6f4de6ebb6dac536ed6c7e06 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); auto x_literal = LiteralUtil::CreateR0(2.f); auto y_literal = LiteralUtil::CreateR0(3.f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); auto x = Parameter(&builder, 0, f32_scalar, "x"); @@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); auto y_literal = LiteralUtil::CreateR0(1.5f); auto z_literal = LiteralUtil::CreateR0(5.5f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); + auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); auto y = Parameter(&builder, 1, f32_scalar, "y"); auto z = Parameter(&builder, 2, f32_scalar, "z"); auto bcast = Broadcast(y, {5}); @@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); auto y_literal = LiteralUtil::CreateR0(1.5f); auto z_literal = LiteralUtil::CreateR0(5.5f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); + auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); auto y = Parameter(&builder, 1, f32_scalar, "y"); auto z = Parameter(&builder, 2, f32_scalar, "y"); auto y_bcast = Broadcast(y, {1, 5, 7}); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index 25d10ab00af11b8ebb8147917e7cdbb21f9a42c4..32cac499c7439af80bafb88ac61b0b078f589599 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12.0f).get(), - LiteralUtil::CreateR0(25.0f).get()}), + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(12.0f), + LiteralUtil::CreateR0(25.0f)}), {pred_arg.get()}, error_spec_); } @@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, CreateR1TupleFloorComputation()); - ComputeAndCompareTuple( - &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({13.0f, 16.0f}).get(), - LiteralUtil::CreateR1({26.0f, 30.0f}).get()}), - {pred_arg.get()}, error_spec_); + ComputeAndCompareTuple(&builder, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({13.0f, 16.0f}), + LiteralUtil::CreateR1({26.0f, 30.0f})}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a tuple of a predicate, a @@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, false_builder_result.ConsumeValueOrDie()); - ComputeAndCompareTuple( - &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(true).get(), - LiteralUtil::CreateR0(12.2f).get(), - LiteralUtil::CreateR1({12.8f, 14.6f}).get()}), - {pred_arg.get()}, error_spec_); + ComputeAndCompareTuple(&builder, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(true), + LiteralUtil::CreateR0(12.2f), + LiteralUtil::CreateR1({12.8f, 14.6f})}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a nested tuple. @@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(46.6f).get(), - LiteralUtil::CreateR1({54.4f, 58.4f}).get()}) - .get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({62.1f, 67.4f}).get(), - LiteralUtil::CreateR0(9.3f).get()}) - .get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(46.6f), + LiteralUtil::CreateR1({54.4f, 58.4f})}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({62.1f, 67.4f}), + LiteralUtil::CreateR0(9.3f)})}), {pred_arg.get()}, error_spec_); } @@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(a).get(), - LiteralUtil::CreateR0(b).get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(a), LiteralUtil::CreateR0(b)}), {x_arg.get(), y_arg.get()}, error_spec_); }; @@ -669,10 +665,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { { // Pred is true case. std::vector args; - args.push_back(std::move( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(123).get(), - LiteralUtil::CreateR0(-42).get()}))); - args.push_back(std::move(*LiteralUtil::CreateR0(true))); + args.push_back( + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(123), + LiteralUtil::CreateR0(-42)})); + args.push_back(LiteralUtil::CreateR0(true)); XlaBuilder builder(TestName() + ".main"); auto p = Parameter(&builder, 0, tuple2, "p0"); auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); @@ -682,10 +678,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { { // Pred is false case. std::vector args; - args.push_back(std::move( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(123).get(), - LiteralUtil::CreateR0(-42).get()}))); - args.push_back(std::move(*LiteralUtil::CreateR0(false))); + args.push_back( + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(123), + LiteralUtil::CreateR0(-42)})); + args.push_back(LiteralUtil::CreateR0(false)); XlaBuilder builder(TestName() + ".main"); auto p = Parameter(&builder, 0, tuple2, "p0"); auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 49375748319ad5fe40db507a034ec4b07adb7e84..72ff1e74a47c8584cb5336c86a1c978c4637a902 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) { TEST_F(ConstantsTest, Empty_3x0x2) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D( + ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D( Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); @@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D(array3d)); + ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { {5.0f, 4.4f}, // p2 }); input_array.FillWithPZ(pz); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4D(input_array); + Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array); { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *input_literal); + ConstantLiteral(&builder, input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } @@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. TEST_F(ConstantsTest, DISABLED_TupleConstant) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()})); + ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0}, {2.0}}), + LiteralUtil::CreateR1({2.0, 42})})); - std::unique_ptr result = - ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); + Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, - LiteralSlice(*result, {0}), error_spec_); - LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(*result, {1}), + LiteralSlice(result, {0}), error_spec_); + LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(result, {1}), error_spec_); } TEST_F(ConstantsTest, Token) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *LiteralUtil::CreateToken()); + ConstantLiteral(&builder, LiteralUtil::CreateToken()); // TODO(b/80000000): tokens cannot be returned from computations. Tuple(&builder, {}); TF_ASSERT_OK(Execute(&builder, {}).status()); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 7a203d6873dbb5b69f96c50048c2c5ff3150c544..5f063e67847487f1d18bf4ee80b1634ebdf4183a 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -210,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { static_cast(0x8000008000000000LL), static_cast(0x8000010000000000LL), }; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); @@ -229,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); @@ -247,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XlaBuilder builder(TestName()); std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, U32); @@ -264,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -281,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -318,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { 9223370937343148032.f, -9223371487098961920.f, -9223370937343148032.f}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -456,7 +456,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*LiteralUtil::CreateR1(input))); + client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( @@ -476,7 +476,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*LiteralUtil::CreateR1(input))); + client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 38b6da4fa96b0f6b7ed2d56852eb3ab2872f3520..fd98bf29b8a06d7476d51174b61c6268750db2ec 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -93,8 +93,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, auto weight_array = absl::make_unique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = - client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) + client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index d2c6478b02423c93860244bc5eb91e652a3eac2e..070b092d18930027e215cb43ff917e36cac99f12 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { {7.0f, 8.0f}, })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); // clang-format on ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { Array3D expected({{{570.0f, 670.0f, 770.0f}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { Array3D expected({{{190, 320, 230, 380, 270, 440, 310, 500}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { Array3D expected({{{510, 0, 610, 0, 710, 0, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -435,23 +435,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota(input_elems.begin(), input_elems.end(), 1.0f); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota(filter_elems.begin(), filter_elems.end(), 1.0f); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); - auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); + auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); - auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie(); + auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r5).ConsumeValueOrDie(); + client_->TransferToServer(filter_r5).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r5, + ComputeAndCompareLiteral(&builder, expected_r5, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -498,23 +498,23 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(92115), static_cast(93150), static_cast(94185)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -558,12 +558,12 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(16029), static_cast(16218), static_cast(16407), @@ -571,14 +571,14 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest { static_cast(18369), static_cast(18576), static_cast(18783), static_cast(19620), static_cast(19836), static_cast(20052), static_cast(20925), static_cast(21150), static_cast(21375)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -624,26 +624,26 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(5076), static_cast(5160), static_cast(5244), static_cast(5328), static_cast(6164), static_cast(6264), static_cast(6364), static_cast(6464), static_cast(7380), static_cast(7496), static_cast(7612), static_cast(7728)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -692,8 +692,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, expected_result.Fill(0); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(param0)), - std::move(*LiteralUtil::CreateFromArray(param1))}, + {LiteralUtil::CreateFromArray(param0), + LiteralUtil::CreateFromArray(param1)}, error_spec_); } @@ -749,26 +749,25 @@ class Convolve1D1WindowTestBase std::vector input_elems(ShapeUtil::ElementsIn(input_shape), static_cast(1.0f)); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), static_cast(1.0f)); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); std::vector expect_elems(batch * output_feature * num_windows, static_cast(window_size * input_feature)); auto expected_r1 = LiteralUtil::CreateR1(expect_elems); - auto expected_r3 = - expected_r1->Reshape({batch, num_windows, output_feature}) - .ConsumeValueOrDie(); + auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature}) + .ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r3).ConsumeValueOrDie(); + client_->TransferToServer(input_r3).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r3).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r3, + client_->TransferToServer(filter_r3).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, expected_r3, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -868,8 +867,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } @@ -891,9 +890,44 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { Array4D filter_data(1, 1, 1, 2); filter_data.FillIota(10); - ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}); + ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}); +} + +XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) { + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100}); + Array4D input_data(1, 64, 100, 100); + input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45321); + Shape filter_shape = ShapeUtil::MakeShape(F32, {7, 7, 1, 64}); + Array4D filter_data(7, 7, 1, 64); + input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45320); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = ConstantR4FromArray4D(&builder, filter_data); + + // Specify bf01_01io->bf01 as dimension numbers. + ConvolutionDimensionNumbers dnums; + // Input + dnums.set_input_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.add_input_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + // Kernel + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + // Output + dnums.set_output_batch_dimension(0); + dnums.set_output_feature_dimension(1); + dnums.add_output_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(3); + ConvGeneral(input, filter, /*window_strides=*/{1, 1}, + /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums, + /*feature_group_count=*/64); + + ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)}, + error_spec_); } class ConvolutionHloTest : public HloTestBase {}; diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 6784c16715da72d337edf70fa51db42c59404136..ba3e9c436e3cfa574a07e881a187ff4c7d6243a1 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { auto gradients_flat = LiteralUtil::CreateR1({1}); auto gradients_literal = - gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); - auto gradients = ConstantLiteral(&builder, *gradients_literal); + gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); + auto gradients = ConstantLiteral(&builder, gradients_literal); auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); auto weights_literal = - weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto weights = ConstantLiteral(&builder, *weights_literal); + weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + auto weights = ConstantLiteral(&builder, weights_literal); auto expected_flat = LiteralUtil::CreateR1({10}); auto expected_literal = - expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); + expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto mirrored_weights = Rev(weights, {2, 3, 4}); ConvWithGeneralPadding(gradients, mirrored_weights, /*window_strides=*/{1, 1, 1}, /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { @@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); auto activations_literal = - activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); - auto activations = ConstantLiteral(&builder, *activations_literal); + activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); + auto activations = ConstantLiteral(&builder, activations_literal); auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); auto gradients_literal = - gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto gradients = ConstantLiteral(&builder, *gradients_literal); + gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + auto gradients = ConstantLiteral(&builder, gradients_literal); auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); auto expected_literal = - expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto forward_conv = ConvGeneralDilated(activations, gradients, @@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { XlaBuilder::CreateDefaultConvDimensionNumbers( /*num_spatial_dims=*/3)); Transpose(forward_conv, {0, 1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); } } // namespace diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 526626c1ddd902a4ba6c608f2b9355cece9ec833..1407e68d9a336b6bb1c960711015430f872aa912 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -40,16 +40,16 @@ class CopyOpTest : public HloTestBase { protected: void TestCopyOp(const Literal& literal) { auto builder = HloComputation::Builder(TestName()); - auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(literal.CloneToUnique())); + auto constant = + builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone())); builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result)); + Literal result = ExecuteAndTransfer(std::move(module), {}); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); @@ -58,31 +58,30 @@ class CopyOpTest : public HloTestBase { }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { - TestCopyOp(*LiteralUtil::CreateR0(true)); + TestCopyOp(LiteralUtil::CreateR0(true)); } XLA_TEST_F(CopyOpTest, CopyR1S0U32) { - TestCopyOp(*LiteralUtil::CreateR1({})); + TestCopyOp(LiteralUtil::CreateR1({})); } XLA_TEST_F(CopyOpTest, CopyR1S3U32) { - TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); + TestCopyOp(LiteralUtil::CreateR1({1, 2, 3})); } XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) { - TestCopyOp( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4( + TestCopyOp(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); + TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); } XLA_TEST_F(CopyOpTest, CopyParameterScalar) { @@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { // Copy literal to device to use as parameter. auto literal = LiteralUtil::CreateR0(42.0); - Shape shape = literal->shape(); + Shape shape = literal.shape(); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")); @@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = - ExecuteAndTransfer(std::move(module), {literal.get()}); - LiteralTestUtil::ExpectR0Near(42.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {&literal}); + LiteralTestUtil::ExpectR0Near(42.0f, result, error_spec_); } XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { @@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, *result, + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, result, error_spec_); } XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major order of the literal. - Layout* literal_layout = - literal->mutable_shape_do_not_use()->mutable_layout(); + Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); literal_layout->mutable_minor_to_major()->SwapElements(0, 1); @@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); // The result of the computation has the default layout, which is the inverse // of the layout of the source literal. - LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, *result, + LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, result, error_spec_); } @@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); + Literal literal = LiteralUtil::CreateR3FromArray3D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -182,9 +178,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0})); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR3EqualArray3D(a, *result); + LiteralTestUtil::ExpectR3EqualArray3D(a, result); } void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, @@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR4FromArray4D(a); + Literal literal = LiteralUtil::CreateR4FromArray4D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR4EqualArray4D(a, *result); + LiteralTestUtil::ExpectR4EqualArray4D(a, result); } XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) { @@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { XlaBuilder builder(TestName()); Parameter(&builder, 0, in_shape, "input"); - auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); + auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) .ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index d12a4e7fcd7813775a81677bcaa07af60ff9b477..410732c07b7b6d3ece33ab11f4778241dc53ca50 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal = LiteralUtil::CreateR1({1, 2, 3}); - EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); + EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal})); } XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { @@ -68,9 +68,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); - EXPECT_EQ( - *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), - *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}), + ExecuteAndTransfer(std::move(module), {&literal0, &literal1})); } // On the GPU backend, constants get special handling. Someone might pass a @@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); - EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), - *ExecuteAndTransfer(std::move(module), {literal0.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}), + ExecuteAndTransfer(std::move(module), {&literal0})); } } // namespace diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 6f7fc0e6e52a69387a4c491871b6fcd97ac638b6..a693fa35954bcb2d95074c94d0aa3eabc1d5fd62 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { module->AddEntryComputation(builder.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR0Near(44.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { @@ -101,8 +101,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { module->AddEntryComputation(builder.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR0Near(10.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, @@ -125,9 +125,9 @@ XLA_TEST_F(CustomCallTest, module->AddEntryComputation(b.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR3EqualArray3D( - Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result); + Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } class CustomCallClientAPITest : public ClientLibraryTestBase {}; diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index eb15fc0593adf2d1bd84da4d0f708b6244f0fb33..e0f23b0fa807ca27038afa2eec5f739508e3d5bd 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { // Try copying the elements back and comparing it auto handles = result_status.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { @@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { auto handles1 = result_status1.ConsumeValueOrDie(); auto handles2 = result_status2.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); handles1[0].reset(); handles1[1].reset(); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { @@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { // the same as handle[3] and handle[1] should be the same as handle[2]. auto handles = result_status.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { @@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { // should not have been deallocated because of reference counting. global_data.reset(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); /// Try deallocating one of the repeated elements, then copy handles[0].reset(); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructNonTuple) { @@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal param0_literal = LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); Tuple(&builder, {p}); auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()}); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 5873516442fa63de47360acaa353abb3a97fe881..0171f515839d556827f0723772214d175939d386 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -68,16 +68,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaOp param; auto param_data = CreateParameterAndTransferLiteral( 0, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), - LiteralUtil::CreateR2({{5, 6}, {7, 8}}).get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::CreateR2({{5, 6}, {7, 8}})}), "arg0", &builder, ¶m); auto lhs = GetTupleElement(param, 0); auto rhs = GetTupleElement(param, 1); Dot(lhs, rhs); ComputeAndCompareLiteral(&builder, - *LiteralUtil::CreateR2({{19, 22}, {43, 50}}), + LiteralUtil::CreateR2({{19, 22}, {43, 50}}), {param_data.get()}); } @@ -196,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) { auto lhs_handle = this->client_ - ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) .ConsumeValueOrDie(); @@ -219,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f}, {3.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -286,24 +286,23 @@ void ParametricDotTest::TestImpl() { std::unique_ptr> dot_lhs_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); - std::unique_ptr dot_lhs_lit = - LiteralUtil::CreateR2FromArray2DWithLayout( - *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor( - param.dot_lhs_row_major))); + Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); std::unique_ptr dot_lhs_handle = - client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); + client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie(); std::unique_ptr> dot_rhs_data = MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); Layout rhs_layout = LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); - std::unique_ptr dot_rhs_lit = + Literal dot_rhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr dot_rhs_handle = - client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); + client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie(); std::unique_ptr> addend_data; - std::unique_ptr addend_lit; + Literal addend_lit; std::unique_ptr addend_handle; if (param.has_addend) { @@ -311,7 +310,7 @@ void ParametricDotTest::TestImpl() { addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); - addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); + addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie(); } XlaBuilder builder(TestName()); @@ -477,14 +476,14 @@ class NonsquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -511,12 +510,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateR2WithLayout( + ->TransferToServer(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateR2WithLayout( + ->TransferToServer(LiteralUtil::CreateR2WithLayout( {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); @@ -584,7 +583,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, @@ -592,7 +591,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{11.0f, 22.0f}, {33.0f, 44.0f}}, {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) @@ -630,13 +629,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}})) .ConsumeValueOrDie(); @@ -668,7 +667,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{9.0f, 10.0f}, {11.0f, 12.0f}}, {{13.0f, 14.0f}, {15.0f, 16.0f}}}})) @@ -676,7 +675,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}, {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}})) .ConsumeValueOrDie(); @@ -708,14 +707,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) { auto lhs_handle = this->client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( *lhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = this->client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( *rhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); @@ -778,15 +777,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); this->template ComputeAndCompareR2( @@ -827,15 +826,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); this->template ComputeAndCompareR2( diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 9bf3767ca3e229cd3eb37c1f51c526c7dd2bf0f8..7501c6d957e7afe99b8c530e5f0d575f818367da 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -124,13 +124,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { // vector is special so that it cannot be a Span, which // is what the code below wants. So instead we do this. Literal input_values = - std::move(*LiteralUtil::CreateR1(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + LiteralUtil::CreateR1(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie(); Literal expected_values = - std::move(*LiteralUtil::CreateR1(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array2D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array3D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { void RunR0(int input_value_int, int update_value_int, const std::vector slice_starts, int expected_value_int) { Literal input_value = - std::move(*LiteralUtil::CreateR0(input_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(input_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_value = - std::move(*LiteralUtil::CreateR0(update_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(update_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_value = - std::move(*LiteralUtil::CreateR0(expected_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(expected_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -390,17 +390,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, absl::Span expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR1(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR1(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR1(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array2D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array3D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void DumpArray(const string& name, const Array3D values) { - std::unique_ptr literal = - LiteralUtil::CreateR3FromArray3D(values); - LOG(INFO) << name << ":" << literal->ToString(); + Literal literal = LiteralUtil::CreateR3FromArray3D(values); + LOG(INFO) << name << ":" << literal.ToString(); } }; @@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) { auto input_literal = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - auto input = ConstantLiteral(&builder, *input_literal); + auto input = ConstantLiteral(&builder, input_literal); // Create dynamic slice start indices as a parameter: shape [4] auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); @@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) { auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - stream.get(), *start_indices_literal, buffer)); + stream.get(), start_indices_literal, buffer)); std::unique_ptr executable = client diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index 5116e60ca63ef5f94b25b15e6616086fb9e44bbb..b08ece0e63e9472f657b49b533511e9b192d3212 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr input, client_->TransferToServer( - *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); + LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); XlaBuilder b(TestName() + ".add"); Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1")); diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index bf1de02ba9dbd97db9ee31484402fe9b92385219..51b50d456e496c9c01c38fb8539bb3737de16937 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -38,29 +38,29 @@ class ExhaustiveF32ElementwiseOpTest XlaBuilder builder(TestName()); - std::unique_ptr input_literal = + Literal input_literal = LiteralUtil::CreateFromDimensions(F32, {input_size}); for (int64 i = begin; i < end; i++) { if (i >= known_incorrect_range.first && i < known_incorrect_range.second) { // If the operation is known to be buggy on a specific input clamp that // input to 0 under the assumption that the op is at least correct on 0. - input_literal->Set({i - begin}, 0.0f); + input_literal.Set({i - begin}, 0.0f); } else { - input_literal->Set({i - begin}, tensorflow::bit_cast(i)); + input_literal.Set({i - begin}, tensorflow::bit_cast(i)); } } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); enqueue_op(&builder, input); std::vector expected_result; expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(evaluate_op(input_literal->Get({i}))); + expected_result.push_back(evaluate_op(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 7cb2f0cedfc2e74386bb3c01ca0b838e7cdcbce9..9c94acb437e9fc948a4255f7112e2e7a40cfa5fb 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -117,9 +117,9 @@ class FusionTest : public HloTestBase { auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4))); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } } @@ -222,8 +222,8 @@ XLA_TEST_F(FusionTest, Test) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{0.5}, {2.72}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{0.5}, {2.72}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } // Test whether we emit appropriate code for parameters of fusion instructions. @@ -248,8 +248,8 @@ XLA_TEST_F(FusionTest, Parameter) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, RandomizedParallelPartition) { @@ -283,7 +283,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { // Every element of result should be y = x^2 = 4.0. for (int i = 0; i < rand_dim0_size; ++i) { for (int j = 0; j < dim1_size; ++j) { - EXPECT_EQ(4.0, result->Get({i, j})); + EXPECT_EQ(4.0, result.Get({i, j})); } } } @@ -308,8 +308,8 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, ReshapeToScalar) { @@ -323,8 +323,8 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(5), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(5), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { @@ -338,8 +338,8 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { @@ -353,8 +353,8 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by1by1_) { @@ -368,8 +368,8 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(7), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__1by1by1) { @@ -383,8 +383,8 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{7}}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{7}}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__) { @@ -398,8 +398,8 @@ XLA_TEST_F(FusionTest, Reshape__) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(7), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { @@ -413,8 +413,8 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_2by3) { @@ -428,8 +428,8 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_3by3) { @@ -443,8 +443,8 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reverse) { @@ -459,8 +459,8 @@ XLA_TEST_F(FusionTest, Reverse) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({3, 2, 1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({3, 2, 1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReverseNegate) { @@ -477,8 +477,8 @@ XLA_TEST_F(FusionTest, ReverseNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-3, -2, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-3, -2, -1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, BroadcastNegate) { @@ -495,8 +495,8 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-1, -1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { @@ -513,8 +513,8 @@ XLA_TEST_F(FusionTest, SliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-1, -3}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { @@ -535,8 +535,8 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-2, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-2, -3}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { @@ -552,9 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, TransposeNegate) { @@ -570,9 +570,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -602,8 +602,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { HloInstruction::FusionKind::kInput); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(15), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(15), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { @@ -624,8 +624,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(-15), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(-15), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { @@ -674,8 +674,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } // When a constant (or other op) which has multiple users is imported @@ -710,8 +710,8 @@ XLA_TEST_F(FusionTest, SharedConstant) { EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({8}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({8}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } @@ -782,19 +782,17 @@ ENTRY main { } )"; - std::unique_ptr operand = - LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); + Literal operand = LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseHloString(hlo_text, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, - test_runner_.Execute(std::move(module), {operand.get()}, - /*run_hlo_passes=*/false)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + test_runner_.Execute(std::move(module), {&operand}, + /*run_hlo_passes=*/false)); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), - *result)); + LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), + result)); } class FusionClientLibraryTest : public ClientLibraryTestBase {}; @@ -821,16 +819,16 @@ XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { // where overflow is OK. Array2D arr(32, 32); arr.FillUnique(); - std::unique_ptr l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout( LayoutUtil::MakeLayout({0, 1})); - std::unique_ptr l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout( LayoutUtil::MakeLayout({1, 0})); - XlaOp p0 = AddParam(*l1, &b); + XlaOp p0 = AddParam(l1, &b); XlaOp sum = p0; for (int i = 1; i < kNumParams; ++i) { - auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b); + auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b); sum = sum + p0 * pN * pN; } @@ -879,19 +877,19 @@ void BM_ParallelFusion(int num_iters) { auto param0_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); ScopedShapedBuffer buffer0 = - client->LiteralToShapedBuffer(*param0_literal, device_ordinal) + client->LiteralToShapedBuffer(param0_literal, device_ordinal) .ConsumeValueOrDie(); auto param1_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); ScopedShapedBuffer buffer1 = - client->LiteralToShapedBuffer(*param1_literal, device_ordinal) + client->LiteralToShapedBuffer(param1_literal, device_ordinal) .ConsumeValueOrDie(); auto param2_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); ScopedShapedBuffer buffer2 = - client->LiteralToShapedBuffer(*param2_literal, device_ordinal) + client->LiteralToShapedBuffer(param2_literal, device_ordinal) .ConsumeValueOrDie(); // Build executable. diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 6d634980449268e509d87ee064fbaaaf59abd195..daa89398a697af9149797d621c3bdca80a00aedd 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -58,10 +58,10 @@ ENTRY main { slice_sizes={1, 3} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { @@ -79,10 +79,10 @@ ENTRY main { slice_sizes={3, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) { @@ -100,11 +100,10 @@ ENTRY main { slice_sizes={3, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) { @@ -122,11 +121,11 @@ ENTRY main { slice_sizes={1, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) { @@ -144,11 +143,11 @@ ENTRY main { slice_sizes={1, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) { @@ -166,13 +165,12 @@ ENTRY main { slice_sizes={1,1,2} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) { @@ -190,13 +188,12 @@ ENTRY main { slice_sizes={1,1,2} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, DynamicSlice) { @@ -214,10 +211,10 @@ ENTRY main { slice_sizes={1,1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) { @@ -235,11 +232,10 @@ ENTRY main { slice_sizes={1,1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ZeroDimBounds) { @@ -257,9 +253,9 @@ ENTRY main { slice_sizes={1, 0} } )"; - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { @@ -281,11 +277,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { @@ -307,11 +303,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { @@ -333,11 +329,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { @@ -359,11 +355,11 @@ ENTRY main { ROOT result = u32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { @@ -381,10 +377,10 @@ ENTRY main { slice_sizes={1,3,2} } )"; - std::unique_ptr operand = LiteralUtil::CreateR3( + Literal operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr start_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ScalarResult) { @@ -402,9 +398,9 @@ ENTRY main { slice_sizes={1} } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); - std::unique_ptr start_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { @@ -422,10 +418,10 @@ ENTRY main { slice_sizes={1, 3} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { @@ -446,10 +442,10 @@ ENTRY main { ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { @@ -470,11 +466,10 @@ ENTRY main { ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { @@ -495,11 +490,11 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { @@ -520,13 +515,12 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, @@ -548,13 +542,12 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { @@ -575,10 +568,10 @@ ENTRY main { ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) { @@ -599,11 +592,10 @@ ENTRY main { ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, &operand, &start_indices); } class GatherClientLibraryTest : public ClientLibraryTestBase {}; @@ -640,10 +632,10 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr operand_arg, client_->TransferToServer( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr indices_arg, - client_->TransferToServer(*LiteralUtil::CreateR1({0, 2}))); + client_->TransferToServer(LiteralUtil::CreateR1({0, 2}))); TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); @@ -657,10 +649,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { TF_ASSERT_OK_AND_ASSIGN( std::vector> result_data, client_->ExecuteParallel(computation_instances)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, client_->Transfer(*(result_data[0]))); - LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, - *result_literal); + LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, result_literal); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 3df99aac7d42610ae90100c46d5cf0809ee569a0..bdd4fd7e3d0f585d81e94a3326e6d24bb5c42f39 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -136,21 +136,21 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() { return debug_options; } -StatusOr> HloTestBase::Execute( - std::unique_ptr module, absl::Span arguments) { +StatusOr HloTestBase::Execute(std::unique_ptr module, + absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments); } -std::unique_ptr HloTestBase::ExecuteNoHloPasses( - std::unique_ptr module, absl::Span arguments) { +Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr module, + absl::Span arguments) { return test_runner_ .Execute(std::move(module), arguments, /*run_hlo_passes=*/false) .ValueOrDie(); } -std::unique_ptr HloTestBase::ExecuteAndTransfer( - std::unique_ptr module, absl::Span arguments) { +Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, + absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } @@ -188,7 +188,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( TF_ASSIGN_OR_RETURN(auto reference, reference_runner_.Execute(std::move(reference_module), arguments, run_hlo_passes)); - return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test, + return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test, error); } @@ -223,13 +223,12 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( ::testing::AssertionResult HloTestBase::RunAndCompare( std::unique_ptr module, const optional& error, const std::function& reference_preprocessor) { - const auto& fake_arguments = - MakeFakeArguments(module.get()).ConsumeValueOrDie(); + auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return RunAndCompare(std::move(module), fake_argument_ptrs, error, reference_preprocessor); @@ -243,7 +242,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, reference_preprocessor); @@ -277,7 +276,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return test_runner_ .Execute(std::move(module_or_status.ValueOrDie()), fake_argument_ptrs, /*run_hlo_passes=*/true) diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 21d77c0cc42a07eedc834775b294872c713a33c0..0ae4bdc104d656946d45008adec9ea3960984545 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -115,16 +115,16 @@ class HloTestBase : public ::testing::Test { } // Executes the given module and return the result as a Literal. - StatusOr> Execute( - std::unique_ptr module, absl::Span arguments); + StatusOr Execute(std::unique_ptr module, + absl::Span arguments); // Same as above, except the module will be executed without running any HLO // passes on it. - std::unique_ptr ExecuteNoHloPasses( - std::unique_ptr module, absl::Span arguments); + Literal ExecuteNoHloPasses(std::unique_ptr module, + absl::Span arguments); - std::unique_ptr ExecuteAndTransfer( - std::unique_ptr module, absl::Span arguments); + Literal ExecuteAndTransfer(std::unique_ptr module, + absl::Span arguments); // Executes the given hlo module on two backends and compares results. // diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 96f72212f35f5e6e98e2dc24fd9a87891a326e8f..43cca91f64b2c0fbfde5054a361cf0f95302c23d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -155,20 +155,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR0(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( absl::Span expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR1(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR2(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR2(expected), actual)); } template @@ -176,46 +176,46 @@ template std::initializer_list>> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR3(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR3(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( const Array2D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( const Array3D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( const Array4D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR0(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( absl::Span expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR1(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR2(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR2(expected), actual, error)); } template @@ -223,7 +223,7 @@ template std::initializer_list>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR3(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR3(expected), actual, error)); } template @@ -232,28 +232,28 @@ template std::initializer_list>>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR4(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 4151bfae0332ffc706ba730d181c487eabab856f..b6f9b8156b51144e4f74d285b1e4111d098f13c2 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -31,11 +31,11 @@ namespace xla { namespace { TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { - std::unique_ptr literal = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + Literal literal = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR0(64), }); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal)); } TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { @@ -43,15 +43,15 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // un-fail an assertion failure. The CHECK-failure is death, so we can make a // death assertion. auto unequal_things_are_equal = [] { - std::unique_ptr lhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + Literal lhs = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR0(64), }); - std::unique_ptr rhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(64).get(), - LiteralUtil::CreateR0(42).get(), + Literal rhs = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(64), + LiteralUtil::CreateR0(42), }); - CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; + CHECK(LiteralTestUtil::Equal(lhs, rhs)) << "LHS and RHS are unequal"; }; ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal"); } @@ -61,7 +61,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { auto two = LiteralUtil::CreateR0(2); auto four = LiteralUtil::CreateR0(4); ErrorSpec error(0.001); - CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; + CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four"; }; tensorflow::Env* env = tensorflow::Env::Default(); @@ -86,14 +86,14 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result, &literal_proto)); - std::unique_ptr literal = + Literal literal = Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal->ToString()); + EXPECT_EQ("2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal->ToString()); + EXPECT_EQ("4", literal.ToString()); } else if (result.find("mismatches") != string::npos) { - EXPECT_EQ("true", literal->ToString()); + EXPECT_EQ("true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } @@ -103,8 +103,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto expected = LiteralUtil::CreateR1({1, 2, 3}); auto actual = LiteralUtil::CreateR1({4, 5, 6}); - ::testing::AssertionResult result = - LiteralTestUtil::Equal(*expected, *actual); + ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual); EXPECT_THAT(result.message(), ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); EXPECT_THAT(result.message(), @@ -116,7 +115,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1) { {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } TEST(LiteralTestUtilTest, NearComparatorR1Nan) { @@ -124,7 +123,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1Nan) { {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); - EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } TEST(LiteralTestUtil, NearComparatorDifferentLengths) { @@ -132,8 +131,8 @@ TEST(LiteralTestUtil, NearComparatorDifferentLengths) { {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); - EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); - EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001})); } } // namespace diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index 237a4a361e386e24c2897c42602eb60ca7234731..dbdd20daf0c3a54ed7b6e2a9d3fb73274d77474a 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); int64 allocation_count_before = allocator_->allocation_count(); @@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { DefaultExecutableBuildOptions(), options); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(*result), error_spec_); // At least one allocation should have been performed when executing the // computation. @@ -92,7 +92,7 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { computation, {}, ExecutableBuildOptions().set_device_ordinal(d), ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); // At least one allocation should have been performed when executing the // computation. diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 1a823cf189b310c62c735419936544ea99fcfbaf..a99b43f4690b3063f76e2cda1e58c9b4ba9a1df4 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) { ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - LiteralTestUtil::ExpectR0Near(123.f, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR0Near(123.f, ShapedBufferToLiteral(result), error_spec_); } @@ -68,10 +68,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) { auto y = ConstantR0(&builder, 123.0f); Add(x, y); - auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0(42.0f)); + auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0(42.0f)); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value}); - LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR0Near(165.f, ShapedBufferToLiteral(result), error_spec_); } @@ -81,10 +81,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { auto y = ConstantR1(&builder, {}); Add(x, y); - auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1({})); + auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1({})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); - LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR1Near({}, ShapedBufferToLiteral(result), error_spec_); } @@ -95,11 +95,11 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { @@ -109,14 +109,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ExecutionProfile profile; ScopedShapedBuffer result = ExecuteLocallyOrDie( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions().set_execution_profile(&profile)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); EXPECT_GT(profile.compute_and_transfer_time_ns(), 0); } @@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( + auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. - auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( + auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); @@ -142,15 +142,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_colmaj), + ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with the parameter values in a different order. ScopedShapedBuffer result_param_swap = ExecuteLocallyOrDie(computation, {&y_array, &x_array}); - LiteralTestUtil::ExpectR2Near( - {{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_param_swap), error_spec_); + LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, + ShapedBufferToLiteral(result_param_swap), + error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { @@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); // Run with col-major result layout. ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie( @@ -174,7 +174,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_colmaj), + ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with row-major result layout. @@ -186,7 +186,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_rowmaj), + ShapedBufferToLiteral(result_rowmaj), error_spec_); } @@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -208,13 +208,13 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {2})); + LiteralSlice(result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -236,15 +236,15 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 0})); + LiteralSlice(result_literal, {0, 0})); LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {0, 1})); + LiteralSlice(result_literal, {0, 1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 2})); + LiteralSlice(result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { Tuple(&builder, {x, y}); auto array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); ExecutableBuildOptions options = DefaultExecutableBuildOptions(); Shape shape_with_layout = ShapeUtil::MakeTupleShape( @@ -268,11 +268,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array}, options, DefaultExecutableRunOptions()); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -298,15 +298,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { Tuple(&builder, {array_sum, vector_diff}); auto computation = builder.Build().ConsumeValueOrDie(); - auto x_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}); - auto y_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({2.0, 4.0, 6.0}).get(), - LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); + auto x_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR1({42.0, 75.0, 123.0})}); + auto y_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({2.0, 4.0, 6.0}), + LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}})}); - auto x_buffer = LiteralToShapedBuffer(*x_literal); - auto y_buffer = LiteralToShapedBuffer(*y_literal); + auto x_buffer = LiteralToShapedBuffer(x_literal); + auto y_buffer = LiteralToShapedBuffer(y_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}); @@ -314,11 +314,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -344,21 +344,20 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { Tuple(&builder, {negate_array, vector_sum}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}) - .get(), - LiteralUtil::CreateR1({222.0, -2.0, 10.0}).get()}); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR1({42.0, 75.0, 123.0})}), + LiteralUtil::CreateR1({222.0, -2.0, 10.0})}); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -377,24 +376,24 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { Tuple(&builder, {Neg(element_0), Add(element_1, element_1)}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}})}); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); + Literal result_0_literal = ShapedBufferToLiteral(result_0); LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralSlice(*result_0_literal, {0})); + LiteralSlice(result_0_literal, {0})); LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, - LiteralSlice(*result_0_literal, {1})); + LiteralSlice(result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); - std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); + Literal result_1_literal = ShapedBufferToLiteral(result_1); LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, - LiteralSlice(*result_1_literal, {0})); + LiteralSlice(result_1_literal, {0})); LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, - LiteralSlice(*result_1_literal, {1})); + LiteralSlice(result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -427,20 +426,19 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { // Feed in a tuple where each two-element vector element is {tuple_index, // -tuple_index}. - std::vector> arg_elements; + std::vector arg_elements; for (int i = 0; i < kElementCount; ++i) { arg_elements.push_back(LiteralUtil::CreateR1({1.0f * i, -1.0f * i})); } - std::unique_ptr arg_literal = - LiteralUtil::MakeTupleOwned(std::move(arg_elements)); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements)); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); + {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_); } } @@ -476,9 +474,9 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. - std::vector> outer_tuple_elements; + std::vector outer_tuple_elements; for (int i = 0; i < kFanout; ++i) { - std::vector> inner_tuple_elements; + std::vector inner_tuple_elements; for (int j = 0; j < kFanout; ++j) { inner_tuple_elements.push_back(LiteralUtil::CreateR0(i + j)); } @@ -487,16 +485,16 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { } auto arg_literal = LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements)); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { - LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}), - error_spec_); + LiteralTestUtil::ExpectR0Near(i + j + i * kFanout + j, + LiteralSlice(result_literal, {i, j}), + error_spec_); } } } @@ -525,23 +523,23 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. - std::unique_ptr arg_literal = LiteralUtil::CreateR0(123.0); + Literal arg_literal = LiteralUtil::CreateR0(123.0); for (int i = 0; i < kTupleDepth; ++i) { - std::vector> arg_vector; + std::vector arg_vector; arg_vector.push_back(std::move(arg_literal)); arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector)); } - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); ShapeIndex index; for (int i = 0; i < kTupleDepth; ++i) { index.push_back(0); } LiteralTestUtil::ExpectR0Equal(165.0, - LiteralSlice(*result_literal, index)); + LiteralSlice(result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -552,7 +550,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -568,7 +566,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { Neg(x); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -585,7 +583,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { Neg(x); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions().set_result_layout( @@ -622,7 +620,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { DefaultExecutableRunOptions().set_device_ordinal(d)); EXPECT_EQ(d, result.device_ordinal()); LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(result)); + ShapedBufferToLiteral(result)); } } } @@ -666,8 +664,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { // As a check to verify that the computation ran of the device associated // with the stream. This is a weak check, but stronger verification is hard. EXPECT_EQ(d, result.device_ordinal()); - LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(result)); + LiteralTestUtil::ExpectR0Equal(42.0f, ShapedBufferToLiteral(result)); } } @@ -745,11 +742,11 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); + Literal tuple_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR1Equal({2.0f, 4.0f, 6.0f}, - LiteralSlice(*tuple_literal, {0})); + LiteralSlice(tuple_literal, {0})); LiteralTestUtil::ExpectR1Equal({1.0f, 2.0f, 3.0f}, - LiteralSlice(*tuple_literal, {1})); + LiteralSlice(tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { @@ -768,7 +765,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { executable_status.ConsumeValueOrDie(); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); @@ -778,7 +775,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { ->BlockHostUntilDone()); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { @@ -792,33 +789,33 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, local_client_->ShapedBufferToLiteral(shaped_buffer)); - EXPECT_EQ(literal, *transferred_literal); + EXPECT_EQ(literal, transferred_literal); }; // Array shapes. - test_to_device_and_back(*LiteralUtil::CreateR0(42.0)); - test_to_device_and_back(*LiteralUtil::CreateR0(true)); - test_to_device_and_back(*LiteralUtil::CreateR1({1.0, 42.0, 744.4})); + test_to_device_and_back(LiteralUtil::CreateR0(42.0)); + test_to_device_and_back(LiteralUtil::CreateR0(true)); + test_to_device_and_back(LiteralUtil::CreateR1({1.0, 42.0, 744.4})); test_to_device_and_back( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); // Null shape (empty tuple). - test_to_device_and_back(*LiteralUtil::MakeTuple({})); + test_to_device_and_back(LiteralUtil::MakeTuple({})); // Non-nested tuples. - test_to_device_and_back( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12223.0).get()})); - test_to_device_and_back( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456.0).get()})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(12223.0)})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456.0)})); // Nested tuple. - test_to_device_and_back(*LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456.0).get()}) - .get(), - LiteralUtil::CreateR0(false).get()})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456.0)}), + LiteralUtil::CreateR0(false)})); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { @@ -832,17 +829,17 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, local_client_->ShapedBufferToLiteral(shaped_buffer)); - EXPECT_EQ(literal, *transferred_literal); + EXPECT_EQ(literal, transferred_literal); }; test_to_device_and_back( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); test_to_device_and_back( - *LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); - test_to_device_and_back(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456789000LL).get()})); + LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456789000LL)})); } XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { @@ -852,7 +849,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { auto constant = ConstantR1(&builder, {1.0f, 2.0f, 3.0f}); Add(in, constant); - std::unique_ptr result; + Literal result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -861,13 +858,13 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), + LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); // Join the thread. thread.reset(); - LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { @@ -884,14 +881,14 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), + LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + TF_ASSERT_OK_AND_ASSIGN(Literal result, local_client_->TransferFromOutfeedLocal( shape, local_client_->default_device_ordinal())); - LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } // Benchmark that measures the overhead of the LocalClient API when running a @@ -922,8 +919,8 @@ void BM_LocalClientOverhead(int num_iters) { auto literal = LiteralUtil::CreateR2({{0, 0, 0}, {0, 0, 0}}); auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal, - buffer)); + ASSERT_IS_OK( + transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer)); const int kWarmups = 2; diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index a8c68fc7fdbad30068af44606f559ca96603fe66..f90ef22d2d549f451f8af231aea834e9f097b12a 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -136,7 +136,7 @@ ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer( .ConsumeValueOrDie(); } -std::unique_ptr LocalClientTestBase::ShapedBufferToLiteral( +Literal LocalClientTestBase::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { return local_client_->ShapedBufferToLiteral(shaped_buffer) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 90095c5d410f1561a1303a0f62f44d22ed5340f9..4027c7b124f8ac6e4b94600871ac32376a3e6467 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -86,8 +86,7 @@ class LocalClientTestBase : public ::testing::Test { // Construct and return a literal containing the array represented by // shaped_buffer. - std::unique_ptr ShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer); + Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); // Execute the given computation on the local client. With and without // options. diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 0732e195d44d738b264361e43d38259c26a4116e..4d327a6fe9c45174a0666fd573a081e0cfe450d2 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -169,11 +169,11 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); + Literal param0_literal = LiteralUtil::CreateR0(42.0); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {}); ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, @@ -183,11 +183,11 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, @@ -197,12 +197,12 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, @@ -211,12 +211,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { TEST_F(MapTest, MapEachF32ElementToS32Constant) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); @@ -224,12 +224,12 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { TEST_F(MapTest, MapEachF32ElementToU32Constant) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); @@ -238,12 +238,12 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) { TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( @@ -255,11 +255,11 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); @@ -271,12 +271,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); @@ -287,12 +287,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + Literal param0_literal = LiteralUtil::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( @@ -342,17 +342,17 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder), {0}); @@ -365,18 +365,18 @@ TEST_F(MapTest, MapBinaryAdder) { // for Map that used to fail in shape inference (b/28989438). XLA_TEST_F(MapTest, AddWithMixedLayouts) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2WithLayout( + Literal param0_literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = LiteralUtil::CreateR2WithLayout( + Literal param1_literal = LiteralUtil::CreateR2WithLayout( {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), {0, 1}); @@ -391,18 +391,18 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddR3_3x0x2) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), {0, 1, 2}); @@ -413,22 +413,22 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - std::unique_ptr param2_literal = + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); + Literal param2_literal = LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); std::unique_ptr param2_data = - client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); + client_->TransferToServer(param2_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); - auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); + auto param2 = Parameter(&builder, 2, param2_literal.shape(), "param2"); Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1( @@ -475,17 +475,17 @@ TEST_F(MapTest, MapOperantionWithBuildError) { Add(x, y); auto error_add = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, error_add, {0}); StatusOr computation_status = builder.Build(); @@ -513,15 +513,15 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { Pow(x, y); auto power = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + Literal param0_literal = LiteralUtil::CreateR0(2.0f); + Literal param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, power, {}); ComputeAndCompareR0(&builder, 32.0f, @@ -540,15 +540,15 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { Sub(y, x); // note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + Literal param0_literal = LiteralUtil::CreateR0(2.0f); + Literal param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, sub_opposite, {}); ComputeAndCompareR0( @@ -565,11 +565,11 @@ TEST_F(MapTestWithFullOpt, MapSquare) { Mul(x, x); auto square = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(10.0f); + Literal param0_literal = LiteralUtil::CreateR0(10.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param0}, square, {}); ComputeAndCompareR0(&builder, 100.0f, {param0_data.get()}, diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index edb592f43ec778a3fe6e5ef936827dd612791760..3f278115e078877de1683574370df7790c2801fd 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -63,11 +63,11 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { }); Exp(data); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 {0.36788f, 1.64872f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { @@ -92,10 +92,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { }); Map(&builder, {data}, add_half, {0, 1}); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 {-0.5f, 1.0f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { @@ -111,10 +111,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { }); Max(lhs, rhs); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 {3.0f, -4.0f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6)); } struct TestLinspaceMaxParam { @@ -200,14 +200,12 @@ class MatOpsDotAddTest TF_ASSERT_OK_AND_ASSIGN( auto lhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); TF_ASSERT_OK_AND_ASSIGN( auto rhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); XlaBuilder builder(TestName()); auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs"); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index c5e0b9b097d032df7ac75b27a2b72869a7d3c7ea..56aaeb0e6878737e6c689e8065d8f1e1871b3472 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -114,10 +114,10 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); expect.PopulateWithValue(size * 1.5f * 3.5f); + Literal literal_r0 = LiteralUtil::CreateR0(-9.0f); auto actual = - ExecuteAndTransfer(std::move(hlo_module), - {LiteralUtil::CreateR0(-9.0f).get(), &arg1}); - EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); + ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1}); + EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_)); } void RunTest1D(bool manual_fusion, int size) { @@ -178,10 +178,9 @@ class MultiOutputFusionTest : public HloTestBase { Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size})); input1.PopulateWithValue(1.); - Literal expect = - std::move(*LiteralUtil::CreateR1({size * 1.5f * 3.5f})); + Literal expect = LiteralUtil::CreateR1({size * 1.5f * 3.5f}); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); - EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_)); } }; @@ -218,10 +217,9 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { LiteralUtil::CreateR0(1.0)), LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(3.0), LiteralUtil::CreateR0(4))); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -247,9 +245,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); - LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, *result); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, result); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { @@ -280,9 +277,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); - LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, *result); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, result); } const char* const kScalarOps = R"( @@ -324,13 +320,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -356,13 +351,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{6, 8}, {10, 12}}), LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -389,13 +383,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), - LiteralUtil::CreateR1({36, 64}), - LiteralUtil::CreateR1({66, 138})), - *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), + LiteralUtil::CreateR1({36, 64}), + LiteralUtil::CreateR1({66, 138})), + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -422,14 +415,13 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -456,15 +448,14 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{6, 8}, {10, 12}}), LiteralUtil::CreateR3( {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -492,16 +483,15 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR1({14, 22}), LiteralUtil::CreateR3( {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), LiteralUtil::CreateR3( {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -530,13 +520,13 @@ XLA_TEST_F(MultiOutputFusionTest, LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); auto init1 = LiteralUtil::CreateR0(5); auto init2 = LiteralUtil::CreateR0(6); - std::unique_ptr result = ExecuteNoHloPasses( - std::move(module), {param.get(), init1.get(), init2.get()}); + Literal result = + ExecuteNoHloPasses(std::move(module), {¶m, &init1, &init2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{167, 172}, {176, 180}}), LiteralUtil::CreateR2({{6, 6}, {6, 8}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -565,10 +555,9 @@ XLA_TEST_F(MultiOutputFusionTest, auto param = LiteralUtil::CreateR3( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}}), LiteralUtil::CreateR3( @@ -576,7 +565,7 @@ XLA_TEST_F(MultiOutputFusionTest, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}})), - *result)); + result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc index 0a0426adcbc1b5b89be0841fa2c4204e2b65abf4..f2460822a61fef11e5c35c731fa6eca5df72b60b 100644 --- a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc +++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc @@ -70,7 +70,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { GetTupleElement(result_tuple, 0); TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); - std::unique_ptr comp_result; + Literal comp_result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -81,41 +81,41 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { VLOG(1) << "Transferring trip count to computation"; // Transfer number of iterations to Infeed. TF_ASSERT_OK( - local_client_->TransferToInfeed(*LiteralUtil::CreateR0(1))); + local_client_->TransferToInfeed(LiteralUtil::CreateR0(1))); // Pick up value from outfeed { VLOG(1) << "Reading from condition outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&int_shape)); - EXPECT_EQ(r->Get({}), 1); + EXPECT_EQ(r.Get({}), 1); } VLOG(1) << "Writing data to infeed"; // Transfer some stuff to Infeed for use inside of loop. TF_ASSERT_OK(local_client_->TransferToInfeed( - *LiteralUtil::CreateR1({10, 20}))); + LiteralUtil::CreateR1({10, 20}))); // Pick up value from outfeed { VLOG(1) << "Reading from body outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&xfeed_shape)); - EXPECT_EQ(r->Get({0}), 11); - EXPECT_EQ(r->Get({1}), 21); + EXPECT_EQ(r.Get({0}), 11); + EXPECT_EQ(r.Get({1}), 21); } { VLOG(1) << "Reading from condition outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&int_shape)); - EXPECT_EQ(r->Get({}), 0); + EXPECT_EQ(r.Get({}), 0); } // Joins the thread thread.reset(); - EXPECT_EQ(comp_result->Get({}), 0); + EXPECT_EQ(comp_result.Get({}), 0); } XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { @@ -145,7 +145,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); - std::unique_ptr comp_result; + Literal comp_result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -154,12 +154,12 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { })); TF_ASSERT_OK( - local_client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + local_client_->TransferToInfeed(LiteralUtil::CreateR0(true))); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&result_shape)); - EXPECT_EQ(r->Get({}), true); + EXPECT_EQ(r.Get({}), true); // Join the thread thread.reset(); diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index cbeddffacfa4a0fc560e8b9f9a8d7bd23ff32e55..6e98167739c234fae335bcc9e024423e7fc87197 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - Pad(AddParam(*LiteralUtil::CreateR1({}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } @@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - Pad(AddParam(*LiteralUtil::CreateR1({}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, DefaultErrorSpec()); } @@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - Pad(AddParam(*LiteralUtil::CreateR1({1, 2, 3}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({1, 2, 3}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } @@ -132,7 +132,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { XlaBuilder b(TestName()); Pad(AddParam(Array4D(2, 0, 3, 2), &b), - AddParam(*LiteralUtil::CreateR0(1.5), &b), + AddParam(LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, DefaultErrorSpec()); @@ -148,7 +148,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0(1.5), &b), + Pad(AddParam(*input, &b), AddParam(LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); auto expected = absl::make_unique>(2, 3, 3, 2); @@ -168,7 +168,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); Pad(AddParam(input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); auto expected = absl::make_unique>(8, 5, 1, 1); @@ -208,10 +208,10 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { const float pad_value = -5.123f; Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = input->Relayout(layout); + input = input.Relayout(layout); - Pad(AddParam(*input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(input, &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -254,10 +254,10 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { input_array(0, 24, 6, 6) = 2.0f; input_array(0, 17, 2, 5) = 3.0f; auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = input->Relayout(layout); + input = input.Relayout(layout); - Pad(AddParam(*input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(input, &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -331,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - Pad(input, AddParam(*LiteralUtil::CreateR0(0.0f), &b), padding_config); + Pad(input, AddParam(LiteralUtil::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -353,8 +353,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - Pad(input, AddParam(*LiteralUtil::CreateR0(3.14f), &b), - padding_config); + Pad(input, AddParam(LiteralUtil::CreateR0(3.14f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -379,7 +378,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -407,7 +406,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -435,7 +434,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -452,13 +451,12 @@ XLA_TEST_P(PadTestFloat, ReducePad) { XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = - Reduce(input, AddParam(*LiteralUtil::CreateR0(0.0), &b), add, {0}); + Reduce(input, AddParam(LiteralUtil::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - Pad(reduce, AddParam(*LiteralUtil::CreateR0(0.0f), &b), - padding_config); + Pad(reduce, AddParam(LiteralUtil::CreateR0(0.0f), &b), padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index f6c762e7a4bee91a26c4c2e033c3717fef6d91d0..dcb4c11c3ccab5992e1ea4fadf02fda8ff77e7ea 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -42,10 +42,9 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR0(3.14159f); + Literal param0_literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); @@ -55,9 +54,9 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0"); @@ -67,10 +66,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal param0_literal = LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); @@ -81,9 +79,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XlaBuilder builder(TestName()); string str("hello world"); - std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); + Literal param0_literal = LiteralUtil::CreateR1U8(str); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(U8, {static_cast(str.size())}), @@ -94,10 +92,10 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); @@ -107,10 +105,10 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + Literal param0_literal = LiteralUtil::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); @@ -123,15 +121,15 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + Literal literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); // Use both parameters // @@ -154,9 +152,9 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. - std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); + Literal literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + client_->TransferToServer(literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2"); @@ -168,15 +166,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - Parameter(&builder, 0, literal0->shape(), "param0"); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Parameter(&builder, 0, literal0.shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + Literal literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - Parameter(&builder, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + Parameter(&builder, 1, literal1.shape(), "param1"); ComputeAndCompareR1(&builder, {10, 20}, {param0_data.get(), param1_data.get()}, @@ -188,18 +186,17 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = - LiteralUtil::CreateR1({10, 20, 30}); + Literal literal1 = LiteralUtil::CreateR1({10, 20, 30}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&builder, 2, literal1->shape(), "param2"); + auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&builder, 2, literal1.shape(), "param2"); // This add is unused. Add(param1, param2); @@ -233,10 +230,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); + Literal literal = LiteralUtil::CreateR1(sum_value); param_data_owner.push_back( - client_->TransferToServer(*literal).ConsumeValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + client_->TransferToServer(literal).ConsumeValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); sum_handle = Add(sum_handle, param); } @@ -268,10 +265,10 @@ XLA_TEST_F(ParamsTest, constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = LiteralUtil::CreateR0(i); + Literal literal = LiteralUtil::CreateR0(i); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); sum_handle = Add(sum_handle, param); } @@ -300,10 +297,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector params; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); + Literal literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); params.push_back(param); sum_handle = Add(sum_handle, param); } @@ -321,13 +318,14 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( param_data.push_back(data.get()); } - std::vector> elements; + std::vector elements; std::vector ptrs; + elements.reserve(kParamCount); for (int i = 0; i < kParamCount; ++i) { elements.push_back(LiteralUtil::CreateR1({target + i, target + i})); - ptrs.push_back(elements.back().get()); + ptrs.push_back(&elements.back()); } - ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data); } // Test large number of parameters flowing into a while-loop. @@ -356,23 +354,23 @@ XLA_TEST_F(ParamsTest, std::vector params; std::vector parameter_shapes; for (int i = 0; i < kParamCount; ++i) { - std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); + Literal literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); params.push_back(param); - parameter_shapes.push_back(literal->shape()); + parameter_shapes.push_back(literal.shape()); } // Add bool parameter for the loop condition. Use a parameter HLO instead of a // constant because DCE may eliminate the while-body otherwise. - std::unique_ptr bool_literal = LiteralUtil::CreateR0(false); + Literal bool_literal = LiteralUtil::CreateR0(false); param_data_owner.push_back( - std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); + std::move(client_->TransferToServer(bool_literal)).ValueOrDie()); XlaOp bool_param = - Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param"); + Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param"); params.push_back(bool_param); - parameter_shapes.push_back(bool_literal->shape()); + parameter_shapes.push_back(bool_literal.shape()); auto init = Tuple(&builder, params); @@ -420,13 +418,14 @@ XLA_TEST_F(ParamsTest, param_data.push_back(data.get()); } - std::vector> elements; + std::vector elements; std::vector ptrs; + elements.reserve(kParamCount); for (int i = 0; i < kParamCount; ++i) { elements.push_back(LiteralUtil::CreateR1({i, i})); - ptrs.push_back(elements.back().get()); + ptrs.push_back(&elements.back()); } - ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data); } #endif @@ -443,9 +442,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR1({4, 5, 6}).get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR1({4, 5, 6}), })) .ConsumeValueOrDie(); @@ -457,34 +456,34 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); XlaBuilder builder(TestName()); - Parameter(&builder, 0, literal->shape(), "input"); + Parameter(&builder, 0, literal.shape(), "input"); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); + client_->TransferToServer(literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3)); } // As above, but for {1, 0} layout. XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); XlaBuilder builder(TestName()); - Parameter(&builder, 0, literal->shape(), "input"); + Parameter(&builder, 0, literal.shape(), "input"); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); + client_->TransferToServer(literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3)); } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + Literal literal = LiteralUtil::CreateR2({ {1, 3}, {2, 4}, }); - const Shape original = literal->shape(); + const Shape original = literal.shape(); { // Reverse the layout present in original, and make that the layout of the // literal. @@ -492,9 +491,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { original.layout().minor_to_major().begin(), original.layout().minor_to_major().end()); std::reverse(original_layout.begin(), original_layout.end()); - *literal->mutable_shape_do_not_use()->mutable_layout() = + *literal.mutable_shape_do_not_use()->mutable_layout() = LayoutUtil::MakeLayout(original_layout); - ASSERT_EQ(2, literal->Get({0, 1})); + ASSERT_EQ(2, literal.Get({0, 1})); } // Use the original shape in building the computation. XlaBuilder builder(TestName()); @@ -503,7 +502,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { Slice(input, {0, 1}, {1, 2}, {1, 1}); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + client_->TransferToServer(literal).ConsumeValueOrDie(); // Check that we got the off-diagonal value that we expected. Array2D expected(1, 1); expected(0, 0) = 2; diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 5f322b768d8620cb64a79bb8fca5fecf282f28f5..8f2c26f0eea9c7a3b33cd77e5977924c1659535a 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -37,8 +37,7 @@ namespace { class PrngTest : public ClientLibraryTestBase { protected: template - std::unique_ptr UniformTest(T a, T b, absl::Span dims, - int64 seed = 42); + Literal UniformTest(T a, T b, absl::Span dims, int64 seed = 42); // Computes the χ² statistic of a sample of the discrete uniform distribution // of the given range size. `expected_count` is the number of times each @@ -49,9 +48,8 @@ class PrngTest : public ClientLibraryTestBase { }; template -std::unique_ptr PrngTest::UniformTest(T a, T b, - absl::Span dims, - int64 seed) { +Literal PrngTest::UniformTest(T a, T b, absl::Span dims, + int64 seed) { XlaBuilder builder(TestName()); RngUniform( ConstantR0(&builder, a), ConstantR0(&builder, b), @@ -60,8 +58,8 @@ std::unique_ptr PrngTest::UniformTest(T a, T b, SetSeed(seed); auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); - EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - actual->EachCell([=](absl::Span, T value) { + EXPECT_THAT(dims, ::testing::ElementsAreArray(actual.shape().dimensions())); + actual.EachCell([=](absl::Span, T value) { EXPECT_LE(a, value); EXPECT_LT(value, b); }); @@ -116,11 +114,10 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { constexpr int64 count = 100; for (int64 seed = 0; seed < count; ++seed) { auto result = UniformTest(low, high, {}, /*seed=*/seed); - result->Literal::EachCell( - [&](absl::Span, bfloat16 value) { - int64 index = static_cast((value - low) / interval); - counts[index]++; - }); + result.EachCell([&](absl::Span, bfloat16 value) { + int64 index = static_cast((value - low) / interval); + counts[index]++; + }); } // Each bucket should have similar amount of counts. That is, not more than // 10% of total counts. This mostly tests that we don't fall into a 1:2:2 @@ -149,7 +146,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); std::vector counts(range_size, 0); - actual->EachCell( + actual.EachCell( [&counts](absl::Span, int32 value) { ++counts[value]; }); int64 sum = 0; for (int32 i = 0; i < range_size; ++i) { @@ -192,12 +189,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) { }; XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, - client_->TransferToServer(*param0_literal)); + client_->TransferToServer(param0_literal)); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto fn = build_sum_rng(builder); Map(&builder, {param0}, fn, {0}); @@ -210,12 +207,11 @@ XLA_TEST_F(PrngTest, MapUsingRng) { computation, /*arguments=*/{param0_data.get()}, &execution_options)); - EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()), - ShapeUtil::ElementsIn(param0_literal->shape())); - for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) { - EXPECT_GE(actual->data()[i], param0_literal->data()[i]); - EXPECT_LT(actual->data()[i], - param0_literal->data()[i] + 1.0f); + EXPECT_EQ(ShapeUtil::ElementsIn(actual.shape()), + ShapeUtil::ElementsIn(param0_literal.shape())); + for (int i = 0; i < ShapeUtil::ElementsIn(actual.shape()); ++i) { + EXPECT_GE(actual.data()[i], param0_literal.data()[i]); + EXPECT_LT(actual.data()[i], param0_literal.data()[i] + 1.0f); } } @@ -238,15 +234,15 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { ExecutionOptions execution_options2 = execution_options_; execution_options2.set_seed(65); - std::unique_ptr result1; + Literal result1; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options1)); } - std::unique_ptr result2; - std::unique_ptr result3; + Literal result2; + Literal result3; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( @@ -257,9 +253,9 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options1)); } - std::unique_ptr result4; - std::unique_ptr result5; - std::unique_ptr result6; + Literal result4; + Literal result5; + Literal result6; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( @@ -273,11 +269,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options_)); } - EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2)); - EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6)); + EXPECT_TRUE(LiteralTestUtil::Equal(result1, result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(result1, result3)); + EXPECT_FALSE(LiteralTestUtil::Equal(result1, result4)); + EXPECT_FALSE(LiteralTestUtil::Equal(result4, result5)); + EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6)); } XLA_TEST_F(PrngTest, TenValuesN01) { diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index 9af9ea4a2229bb6ca7c3561350f11837f5072a2c..c9096fb29b2019796c42b69de80c63b5fc7c5c3a 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -92,7 +92,7 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) { *reduce_input_shape->mutable_layout() = LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major); - std::unique_ptr reduce_input = LiteralUtil::CreateR4( + Literal reduce_input = LiteralUtil::CreateR4( {{ /*i0=0*/ {/*i1=0*/ {-0.246092796, -0.179497838, -0.161181688}, diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 0916a07f4fa99af6cf25441fa8558a558bfa032f..26e2bfde5cdc19657640f24f31bc008d09ad7106 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -231,11 +231,10 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = - LiteralUtil::CreateR1({input_values}); + Literal a_literal = LiteralUtil::CreateR1({input_values}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); ReducePrecision(a, exponent_bits, mantissa_bits); @@ -255,10 +254,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // Abs doesn't affect resolution. auto abs = Abs(a); @@ -284,10 +283,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -310,10 +309,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -334,10 +333,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -359,10 +358,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 57f7fed61f6d4755b79b069fe1c2d9ced6bb8932..83997cdac21c437d460dabdbdfdb31100b1359af 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -81,9 +81,9 @@ class ReduceTest : public ClientLibraryTestBase { }, 4); // clang-format on CHECK(ShapeUtil::Equal( - literal_3d_->shape(), + literal_3d_.shape(), ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3}))) - << literal_3d_->shape().ShortDebugString(); + << literal_3d_.shape().ShortDebugString(); } // Runs an R1 => R0 reduction test with the given number of elements. @@ -102,10 +102,9 @@ class ReduceTest : public ClientLibraryTestBase { input_data[i] *= -1; } } - std::unique_ptr input_literal = - LiteralUtil::CreateR1(AsSlice(input_data)); + Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data)); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (float item : input_data) { @@ -134,9 +133,9 @@ class ReduceTest : public ClientLibraryTestBase { Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr input_literal = LiteralUtil::CreateR1(input_data); + Literal input_literal = LiteralUtil::CreateR1(input_data); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); bool expected = and_reduce; for (bool item : input_data) { @@ -175,12 +174,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(0, 1); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::array expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -209,12 +207,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (int64 rowno = 0; rowno < rows; ++rowno) { @@ -237,12 +234,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -295,12 +291,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillUnique(initial_value); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); // NativeT can be bool, and std::vector does not convert to // Span. @@ -352,8 +347,8 @@ class ReduceTest : public ClientLibraryTestBase { reference_reduction_function_for_uints, unsigned_int_identity); } - std::unique_ptr literal_2d_; - std::unique_ptr literal_3d_; + Literal literal_2d_; + Literal literal_3d_; uint32 seed_ = 0xdeadbeef; }; @@ -450,11 +445,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -482,11 +476,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -511,10 +504,9 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2}); Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - MakeFakeLiteral(input_shape)); + TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape)); - ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4)); + ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4)); } XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { @@ -531,10 +523,9 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3D(input_data); + Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 major = 0; major < 2; ++major) { @@ -595,7 +586,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { Array2D input(300, 250); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; input.Each( @@ -610,7 +601,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { Array2D input(150, 130); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0(&builder, FLT_MAX), min, {0, 1}); auto input_min = FLT_MAX; @@ -627,7 +618,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { auto initial_value = ConstantR0(&builder, std::numeric_limits::max()); - Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, min, {0, 1}); ComputeAndCompareR0(&builder, 1, {}); } @@ -639,14 +630,14 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { auto initial_value = ConstantR0(&builder, std::numeric_limits::min()); - Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, max, {0, 1}); ComputeAndCompareR0(&builder, 2, {}); } // Reduces a matrix among dimension 1. XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); @@ -657,7 +648,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar). XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); @@ -667,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Tests 2D matrix ReduceToRow operation. XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XlaBuilder builder("reduce_among_y"); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); @@ -677,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1, 2}); @@ -687,7 +678,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); @@ -697,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1, 2}); @@ -707,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); @@ -722,7 +713,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); @@ -739,7 +730,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {2}); @@ -824,12 +815,12 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); + input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); XlaComputation add = CreateScalarAddComputation(F32, &builder); Reduce(input_activations, ConstantR0(&builder, 0.0f), add, GetParam().reduce_dims); @@ -873,11 +864,10 @@ XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) { auto a = ConstantR0(&builder, 2.0f); auto a2 = Abs(a); - std::unique_ptr b_literal = - LiteralUtil::CreateR1({1.0f, 4.0f}); + Literal b_literal = LiteralUtil::CreateR1({1.0f, 4.0f}); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b = Parameter(&builder, 0, b_literal->shape(), "b"); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); + auto b = Parameter(&builder, 0, b_literal.shape(), "b"); Reduce(b, a2, max_f32, {0}); ComputeAndCompareR0(&builder, 4.0f, {b_data.get()}); @@ -904,9 +894,9 @@ class ReduceInitializerTest : public ReduceTest { std::vector input_arr(num_elems, std::numeric_limits::lowest()); auto input_literal = LiteralUtil::CreateR1(input_arr); auto input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, - max_fn, {0}); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn, + {0}); ComputeAndCompareR0(&builder, initializer, {input_data.get()}); } @@ -952,13 +942,12 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { float operand[] = {42.0f}; float init = 58.5f; float expected = 42.0f; - std::unique_ptr input_literal = - LiteralUtil::CreateR1(operand); + Literal input_literal = LiteralUtil::CreateR1(operand); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - std::unique_ptr input_literal2 = LiteralUtil::CreateR0(init); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Literal input_literal2 = LiteralUtil::CreateR0(init); std::unique_ptr input_global_data2 = - client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); + client_->TransferToServer(input_literal2).ConsumeValueOrDie(); ComputeAndCompareR0( &builder, expected, {input_global_data.get(), input_global_data2.get()}, ErrorSpec(0.0001)); diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index a1001296a1f1071ae91ac62f3f2a428691d8837e..63491a90bf2634a53591e2ab431781f3c4237681 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -73,7 +73,7 @@ class ReduceWindowTest : public ::testing::WithParamInterface, absl::Span window_dimensions, absl::Span window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), + auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0(0.0f), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), @@ -107,9 +107,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); + LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); const auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); ReduceWindow(input, init_value, CreateScalarAddComputation(FloatType(), &builder_), @@ -124,31 +124,31 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { // Regression test for b/68964348. TEST_P(ReduceWindowTest, R0ReduceWindow) { const auto input = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(42.0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(42.0), &builder_); const auto init = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(1.0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(1.0), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{}, /*window_strides=*/{}, Padding::kSame); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0(43.0), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0(43.0), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride2) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({100, 1}), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({100, 1}), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1}, Padding::kSame); ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), + LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), {}, ErrorSpec(0.00001)); } @@ -161,7 +161,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -176,7 +176,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -190,7 +190,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -207,7 +207,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -229,8 +229,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { @@ -252,8 +252,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests the super windowing logic w.r.t handling prime number of windows in a @@ -277,8 +277,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { @@ -294,8 +294,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. @@ -313,12 +313,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { auto lhs = Parameter(b.get(), 0, scalar, "lhs"); auto rhs = Parameter(b.get(), 1, scalar, "rhs"); Min(Add(lhs, rhs), - CreateConstantFromLiteral(*LiteralUtil::CreateR0(8.0f), b.get())); + CreateConstantFromLiteral(LiteralUtil::CreateR0(8.0f), b.get())); XlaComputation reduce_fn = b->BuildAndNoteError(); ReduceWindow( input, - CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), &builder_), + CreateConstantFromLiteral(LiteralUtil::CreateR0(0.0f), &builder_), reduce_fn, /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); @@ -332,19 +332,18 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected), {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); input_array.FillRandom(2.f, 2.f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -352,7 +351,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -360,9 +359,9 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - auto arg_literal = absl::make_unique(shape); - arg_literal->PopulateWithValue(1.0f); - const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + Literal arg_literal(shape); + arg_literal.PopulateWithValue(1.0f); + const auto input = CreateConstantFromLiteral(arg_literal, &builder_); Padding padding = Padding::kValid; ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); @@ -371,39 +370,38 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - auto expected = absl::make_unique(result_shape); - expected->PopulateWithValue(27.0f); - ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); + Literal expected(result_shape); + expected.PopulateWithValue(27.0f); + ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - std::unique_ptr arg_literal = + Literal arg_literal = LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); - const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + const auto input = CreateConstantFromLiteral(arg_literal, &builder_); Padding padding = Padding::kValid; ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); std::vector output_dims = {8, 8, 6, 6, 8, 8}; - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateFullWithDescendingLayout(output_dims, 9.0f); - ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 1; int stride = 8; @@ -413,19 +411,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 3; int stride = 1; @@ -435,19 +432,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 8; int stride = 5; @@ -457,7 +453,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -478,18 +474,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); ComputeAndCompareLiteral( &builder_, - *LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, DefaultErrorSpec()); } @@ -504,9 +500,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -521,9 +517,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {1}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -540,9 +536,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateFromArray(*res), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { @@ -556,9 +551,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateFromArray(*res), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), + {}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, @@ -594,7 +588,7 @@ string R4ReduceWindowTestDataToString( // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -614,11 +608,10 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Array4D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], param.base_bounds[3]); input.FillRandom(0.1f, 0.1f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(4); @@ -627,7 +620,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, } auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto reducer = param.reducer; if (use_bfloat16() && Product(param.window_bounds) > 128) { @@ -659,12 +652,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - std::unique_ptr expected_literal = - LiteralUtil::CreateFromArray(*expected); + Literal expected_literal = LiteralUtil::CreateFromArray(*expected); const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( - input_literal->shape().element_type(), - AsInt64Slice(expected_literal->shape().dimensions()), param.layout); - ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()}, + input_literal.shape().element_type(), + AsInt64Slice(expected_literal.shape().dimensions()), param.layout); + ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()}, DefaultErrorSpec(), &expected_shape_with_layout); } }; @@ -988,7 +980,7 @@ string R3ReduceWindowTestDataToString( param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -1008,12 +1000,11 @@ TEST_P(R3ReduceWindowTest, DoIt) { Array3D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2]); input.FillRandom(0.1f, 0.1f); - std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); auto reducer = param.reducer; if (use_bfloat16()) { - input_literal = LiteralUtil::ConvertF32ToBF16(*input_literal); + input_literal = LiteralUtil::ConvertF32ToBF16(input_literal); if (Product(param.window_bounds) > 128) { // To avoid numerical issues, force the reducer to be kMax for large bf16 // windows. @@ -1021,9 +1012,9 @@ TEST_P(R3ReduceWindowTest, DoIt) { } } - XlaOp parameter = Parameter(&b, 0, input_literal->shape(), "input"); + XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input"); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); auto computation = reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) @@ -1035,7 +1026,7 @@ TEST_P(R3ReduceWindowTest, DoIt) { /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, /*padding=*/param.padding); - ComputeAndCompare(&b, {std::move(*input_literal)}, DefaultErrorSpec()); + ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P( @@ -1130,7 +1121,7 @@ string R2ReduceWindowTestDataToString( param.layout[1], // "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -1147,12 +1138,11 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, const float kInitValue = 0.0f; Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(2); for (int i = 0; i < 2; ++i) { @@ -1162,7 +1152,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1178,7 +1168,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } }; @@ -1332,7 +1322,7 @@ string R1ReduceWindowTestDataToString( "__pad_high_", absl::StrJoin(param.pad_high, "x"), "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -1352,11 +1342,11 @@ TEST_P(R1ReduceWindowTest, DoIt) { const float kInitValue = 0.0f; std::vector input_vector(param.base_bounds[0]); std::iota(std::begin(input_vector), std::end(input_vector), 0); - std::unique_ptr input_literal = + Literal input_literal = LiteralUtil::CreateR1(absl::Span(input_vector)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); + auto input_arg = + CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(1); padding[0] = {param.pad_low[0], param.pad_high[0]}; @@ -1365,7 +1355,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1384,7 +1374,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1(*expected), + ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1(*expected), {input_arg.get()}, DefaultErrorSpec()); } diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index d8914513819415368a628eab1f482f9644dd46b1..5cf87e565bf493167f5173588e7afa3b96282488 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -58,13 +58,13 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); // Run it. - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect 4. - LiteralTestUtil::ExpectR0Equal(4, *literal); + LiteralTestUtil::ExpectR0Equal(4, literal); } XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { @@ -91,12 +91,12 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Run it. std::unique_ptr x_data = - client_->TransferToServer(*LiteralUtil::CreateR0(2)) + client_->TransferToServer(LiteralUtil::CreateR0(2)) .ConsumeValueOrDie(); std::unique_ptr y_data = - client_->TransferToServer(*LiteralUtil::CreateR0(3)) + client_->TransferToServer(LiteralUtil::CreateR0(3)) .ConsumeValueOrDie(); - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{x_data.get(), y_data.get()}, @@ -104,7 +104,7 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { .ConsumeValueOrDie(); // Expect 5. - LiteralTestUtil::ExpectR0Equal(5, *literal); + LiteralTestUtil::ExpectR0Equal(5, literal); } TEST_F(ReplayTest, MapPlusTwoOverR1) { @@ -136,13 +136,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); // Run it. - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect result. - LiteralTestUtil::ExpectR1Equal({3, 4, 5}, *literal); + LiteralTestUtil::ExpectR1Equal({3, 4, 5}, literal); } } // namespace diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 17d12715f60f624c35169048121ca139d78a544f..dedc95b5ae8315185a35f786af42aad53bd7ad96 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -57,12 +57,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -70,12 +70,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -83,12 +83,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -99,29 +99,29 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(1.0f); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + Literal param0_literal = LiteralUtil::CreateR0(1.0f); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", &builder, ¶meter); auto a = Neg(parameter); Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); auto expected_literal = LiteralUtil::CreateR1({-1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -130,25 +130,25 @@ XLA_TEST_P(ReshapeTest, Trivial0x3) { Array2D input_array(0, 3); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -157,11 +157,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) { Array2D input_array(3, 0); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -170,11 +170,11 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -183,11 +183,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f}, {2.0f}, {3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -196,12 +196,12 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -211,13 +211,13 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); auto expected_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -226,12 +226,12 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 2)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -241,14 +241,14 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = LiteralUtil::CreateFromArray(*simple); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -258,14 +258,14 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 4}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -274,11 +274,11 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 4)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}, {}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -288,13 +288,13 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -304,13 +304,13 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(6, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 3, 0, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 0, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -318,12 +318,12 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{24, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(24, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -334,14 +334,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 6}); auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -349,12 +349,12 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 6)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(3, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -365,14 +365,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); auto expected_literal = LiteralUtil::CreateFromArray(expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -391,14 +391,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -406,7 +406,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{8, 3}); @@ -418,7 +418,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { {35, 36, 37}, {40, 41, 42}, {45, 46, 47}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -426,14 +426,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -441,7 +441,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{8, 3}); @@ -453,7 +453,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { {45, 16, 26}, {36, 46, 17}, {27, 37, 47}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -461,14 +461,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{2, 6, 2}); auto expected_literal = LiteralUtil::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -494,14 +494,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { t2x2x2x3.FillWithYX(*filler2x3); auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); auto expected_literal = LiteralUtil::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -519,14 +519,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 1, 1) = 7; auto input_literal = LiteralUtil::CreateFromArray(t); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 4}); auto expected_literal = LiteralUtil::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -547,7 +547,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { Reshape(parameter, dimensions, {}); auto expected_literal = LiteralUtil::CreateR0(83.0f); - ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&b, expected_literal, {input.get()}, zero_error_spec_); } } @@ -556,7 +556,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); Reshape(parameter, {}, {}); EXPECT_THAT( @@ -568,7 +568,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), @@ -604,7 +604,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); @@ -619,27 +619,26 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, {1, 0}); - std::unique_ptr actual = + Literal actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); - std::unique_ptr expected = - LiteralUtil::CreateR2FromArray2D(expected_array); + Literal expected = LiteralUtil::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = LiteralUtil::ConvertF32ToBF16(*expected); + expected = LiteralUtil::ConvertF32ToBF16(expected); } - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR2({ + Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); @@ -653,20 +652,20 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { {{204, 205, 206, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } // Tests R2->R4 reshape with the reshape dimensions {1, 0}. XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR2({ + Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); @@ -680,7 +679,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { {{206, 7, 107, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -691,17 +690,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { Array4D input(2, 1, 1, 1); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -712,17 +709,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { Array4D input(2, 1, 4, 1); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -734,12 +729,11 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { Array4D input(5, 10, 2, 3); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); @@ -749,7 +743,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { *cell; }); auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -761,12 +755,11 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { input_array.Each( [&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); XlaComputation computation = builder.Build().ConsumeValueOrDie(); @@ -775,7 +768,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, {2, 3, 0, 1}); - std::unique_ptr output_literal = + Literal output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, &execution_options) @@ -784,10 +777,10 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. if (use_bfloat16()) { - auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal); - EXPECT_EQ(expected->data(), output_literal->data()); + auto expected = LiteralUtil::ConvertF32ToBF16(input_literal); + EXPECT_EQ(expected.data(), output_literal.data()); } else { - EXPECT_EQ(input_literal->data(), output_literal->data()); + EXPECT_EQ(input_literal.data(), output_literal.data()); } } @@ -798,12 +791,12 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); + ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {input.get()}); } XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { @@ -813,7 +806,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { XlaBuilder builder(TestName()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); @@ -830,7 +823,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { {{16}, {20}, {24}}}}); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()}); + ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {input.get()}); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { @@ -841,24 +834,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { @@ -869,24 +861,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { @@ -897,24 +888,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { @@ -926,24 +916,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { @@ -954,24 +943,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({0, 1, 2, 3})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) - ->Relayout(input_literal->shape().layout()); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, input_literal) + .Relayout(input_literal.shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 74ded82ddfae10c21fe98ec2e250b4eaecf95222..4e55b0d7ac4453d074500f3a7fda96cb5ab52c56 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -83,25 +83,25 @@ TEST_P(FloatReverseTest, Reverses) { ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims))); std::iota(input_vector.begin(), input_vector.end(), 0.0); auto r1_literal = LiteralUtil::CreateR1(input_vector); - auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); + auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto a = AddParam(*input_literal, &builder); + auto a = AddParam(input_literal, &builder); Rev(a, spec.reversal); - std::unique_ptr expected = input_literal->CloneToUnique(); + Literal expected = input_literal.Clone(); std::vector output_indices(spec.input_dims.size()); - expected->EachCell([&](absl::Span indices, float) { + expected.EachCell([&](absl::Span indices, float) { for (int64 i = 0; i < indices.size(); ++i) { output_indices[i] = indices[i]; } - float value = input_literal->Get(indices); + float value = input_literal.Get(indices); for (int64 dim : spec.reversal) { output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; } - expected->Set(output_indices, value); + expected.Set(output_indices, value); }); - ComputeAndCompareLiteral(&builder, *expected, {}); + ComputeAndCompareLiteral(&builder, expected, {}); } INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest, diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index e692b8c5d5e661587bac16a2992e35f92c4c0bd9..091a5d2cacce6ac5bf986776e5ec96612d08cc75 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -38,7 +38,7 @@ namespace { class RoundTripPackedLiteralTest : public ClientLibraryTestBase { protected: // Sends the literal to the server and retrieves it back. - std::unique_ptr RoundTripToServer(const Literal& original) { + Literal RoundTripToServer(const Literal& original) { std::unique_ptr data = client_->TransferToServer(original).ConsumeValueOrDie(); return client_->Transfer(*data).ConsumeValueOrDie(); @@ -59,12 +59,12 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0, actual->Get({0})); - EXPECT_EQ(24.0, actual->Get({1})); + EXPECT_EQ(42.0, actual.Get({0})); + EXPECT_EQ(24.0, actual.Get({1})); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { @@ -87,18 +87,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = - reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) - .ConsumeValueOrDie(); + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, actual->Get({0, 0})); - EXPECT_EQ(24.0f, actual->Get({0, 1})); - EXPECT_EQ(64.0f, actual->Get({1, 0})); - EXPECT_EQ(46.0f, actual->Get({1, 1})); + EXPECT_EQ(42.0f, actual.Get({0, 0})); + EXPECT_EQ(24.0f, actual.Get({0, 1})); + EXPECT_EQ(64.0f, actual.Get({1, 0})); + EXPECT_EQ(46.0f, actual.Get({1, 1})); - std::unique_ptr round_tripped = RoundTripToServer(*actual); - EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); + Literal round_tripped = RoundTripToServer(actual); + EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual)); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { @@ -121,18 +120,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = - reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) - .ConsumeValueOrDie(); + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, actual->Get({0, 0})); - EXPECT_EQ(24.0f, actual->Get({1, 0})); - EXPECT_EQ(64.0f, actual->Get({0, 1})); - EXPECT_EQ(46.0f, actual->Get({1, 1})); + EXPECT_EQ(42.0f, actual.Get({0, 0})); + EXPECT_EQ(24.0f, actual.Get({1, 0})); + EXPECT_EQ(64.0f, actual.Get({0, 1})); + EXPECT_EQ(46.0f, actual.Get({1, 1})); - std::unique_ptr round_tripped = RoundTripToServer(*actual); - EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); + Literal round_tripped = RoundTripToServer(actual); + EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index a8193c2eac05ba4f0df339909f3e82a28ac35253..cd5a531603b0cb6e0f48f4dcd49891cbd5428602 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -39,69 +39,67 @@ class RoundTripTransferTest : public ClientLibraryTestBase { void RoundTripTest(const Literal& original) { std::unique_ptr data = client_->TransferToServer(original).ConsumeValueOrDie(); - std::unique_ptr result = - client_->Transfer(*data).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(original, *result)); + Literal result = client_->Transfer(*data).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralTestUtil::Equal(original, result)); } }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*LiteralUtil::CreateR0(42)); + RoundTripTest(LiteralUtil::CreateR0(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*LiteralUtil::CreateR0(42.0)); + RoundTripTest(LiteralUtil::CreateR0(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*LiteralUtil::CreateR1({})); + RoundTripTest(LiteralUtil::CreateR1({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); + RoundTripTest(LiteralUtil::CreateR1({42.0, 64.0})); } TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector values(256); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest( - *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); + RoundTripTest(LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*LiteralUtil::CreateR4({{ + RoundTripTest(LiteralUtil::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -109,36 +107,35 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - RoundTripTest(*LiteralUtil::MakeTuple({})); + RoundTripTest(LiteralUtil::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2}), + LiteralUtil::CreateR1({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({}), + LiteralUtil::CreateR1({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), - LiteralUtil::CreateR1({2, 3}).get()})); + RoundTripTest(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(1.0), LiteralUtil::CreateR1({2, 3})})); } // Below two tests are added to identify the cost of large data transfers. TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); + RoundTripTest(LiteralUtil::CreateR4FromArray4D(array4d)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 07460a7e01a5497aa6411ddb6866dddfc70f2068..1dd937a6d0656b53f9e7e0cb25acf80f0c3d59c0 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -161,9 +161,9 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { ConvertElementType(a, F32); int64 value = 3LL << 35; - std::unique_ptr a_literal = LiteralUtil::CreateR0(value); + Literal a_literal = LiteralUtil::CreateR0(value); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); ComputeAndCompareR0(&builder, static_cast(value), {a_data.get()}); } @@ -225,20 +225,20 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); - std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); - std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); + Literal a_literal = LiteralUtil::CreateR0(2.1f); + Literal b_literal = LiteralUtil::CreateR0(5.5f); + Literal c_literal = LiteralUtil::CreateR0(0.5f); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); std::unique_ptr c_data = - client_->TransferToServer(*c_literal).ConsumeValueOrDie(); + client_->TransferToServer(c_literal).ConsumeValueOrDie(); - XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a"); - XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b"); - XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c"); + XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a"); + XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b"); + XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c"); Mul(Mul(a, b), c); ComputeAndCompareR0(&builder, 5.775f, @@ -377,9 +377,9 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { auto dividend_literal = LiteralUtil::CreateR0(dividend); auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, - client_->TransferToServer(*dividend_literal)); + client_->TransferToServer(dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + client_->TransferToServer(divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(div_computation, @@ -388,7 +388,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { .ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(dividend / divisor); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal)); } } } @@ -419,9 +419,9 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { auto dividend_literal = LiteralUtil::CreateR0(dividend); auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, - client_->TransferToServer(*dividend_literal)); + client_->TransferToServer(dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + client_->TransferToServer(divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(rem_computation, @@ -430,7 +430,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { .ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(dividend % divisor); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal)); } } } @@ -441,8 +441,8 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); Rem(x, ConstantR0(&builder, 80000)); - std::unique_ptr literal = LiteralUtil::CreateR0(87919); - TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); + Literal literal = LiteralUtil::CreateR0(87919); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); ComputeAndCompareR0(&builder, 7919, {input_data.get()}); } diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 1858dcea61241a2aeee11592a9b09f200763b25a..d20dba028a586fa7c93c96dca03c77e3668fa644 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -62,13 +62,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { @@ -92,13 +90,12 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { @@ -123,13 +120,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { @@ -154,13 +149,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) { @@ -185,13 +178,12 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR2( + Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({2, 1}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); + Literal updates = LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) { @@ -216,13 +208,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) { @@ -247,13 +237,12 @@ ENTRY main { index_vector_dim=2 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal updates = LiteralUtil::CreateR3( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterNd) { @@ -277,15 +266,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) { @@ -309,15 +296,13 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, DynamicUpdateSlice) { @@ -341,12 +326,11 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) { @@ -370,13 +354,11 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - std::unique_ptr updates = - LiteralUtil::CreateR3({{{10}}, {{20}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, ZeroDimBounds) { @@ -400,11 +382,10 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{}, {}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, NoUpdateWindowDims) { @@ -429,12 +410,11 @@ ENTRY main { index_vector_dim=2 } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr scatter_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal scatter_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20}, {30, 40}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal updates = LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OutOfBoundsIndex) { @@ -458,13 +438,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) { @@ -488,13 +468,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, NegativeIndex) { @@ -518,13 +498,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OneScalarIndex) { @@ -548,12 +528,12 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = LiteralUtil::CreateR3( + Literal operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR0(1); + Literal updates = LiteralUtil::CreateR3({{{10, 20}, {30, 40}, {50, 60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, ScalarUpdate) { @@ -577,10 +557,10 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); - std::unique_ptr updates = LiteralUtil::CreateR0(25); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal scatter_indices = LiteralUtil::CreateR0(1); + Literal updates = LiteralUtil::CreateR0(25); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, EmptyIndices) { @@ -604,10 +584,10 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR1({}); - std::unique_ptr updates = LiteralUtil::CreateR1({}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3}); + Literal scatter_indices = LiteralUtil::CreateR1({}); + Literal updates = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } } // namespace diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index c9a58aefb4acc066c10e98aea46375523cf554d0..a40c2d7de6eceea489004f5266d060f26da5d1a8 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -176,8 +176,8 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { XlaBuilder builder(TestName()); auto original = ConstantR4FromArray4D(&builder, values); Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), - &expected_literal->shape()); + ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(0.000001), + &expected_literal.shape()); } struct R1Spec { @@ -201,7 +201,7 @@ class SliceR1Test : public ClientLibraryTestBase, auto literal = LiteralUtil::CreateR1(input); XlaBuilder builder(TestName()); - auto original = Parameter(&builder, 0, literal->shape(), "p0"); + auto original = Parameter(&builder, 0, literal.shape(), "p0"); Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); @@ -213,7 +213,7 @@ class SliceR1Test : public ClientLibraryTestBase, } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); ComputeAndCompareR1(&builder, expected, {arg.get()}); } }; @@ -376,11 +376,11 @@ XLA_TEST_P(SliceR2Test, DoIt) { input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); - auto a = Parameter(&builder, 0, literal->shape(), "p0"); + auto a = Parameter(&builder, 0, literal.shape(), "p0"); Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); std::unique_ptr> expected = ReferenceUtil::Slice2D( input, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR2(&builder, *expected, {arg.get()}); @@ -467,9 +467,9 @@ class SliceR4Test : public ClientLibraryTestBase, XlaBuilder builder(TestName()); auto literal = LiteralUtil::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); - auto parameter = Parameter(&builder, 0, literal->shape(), "p0"); + auto parameter = Parameter(&builder, 0, literal.shape(), "p0"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001)); } diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 3ae31191a044458663d8b8034b3368b65ef7e771..5155f0c652c7c6dbba60c421159494fa28072090 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -116,13 +116,14 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, // array. This is uniqueness is best-effort only. Some types (half and bfloat16) // are not supported and uniqueness cannot be guaranteed if the number of // elements exceeds the number of different values supported by the type. -StatusOr> MakeFakeLiteralInternal( - const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) { +StatusOr MakeFakeLiteralInternal(const Shape& shape, + std::minstd_rand0* engine, + bool no_duplicates) { if (ShapeUtil::IsTuple(shape)) { - std::vector> elements; + std::vector elements; for (const Shape& element_shape : shape.tuple_shapes()) { TF_ASSIGN_OR_RETURN( - std::unique_ptr element, + Literal element, MakeFakeLiteralInternal(element_shape, engine, no_duplicates)); elements.push_back(std::move(element)); } @@ -131,60 +132,52 @@ StatusOr> MakeFakeLiteralInternal( if (engine == nullptr) { return Literal::CreateFromShape(shape); } - auto literal = absl::make_unique(shape); + Literal literal(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F32: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case S8: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U8: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S16: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U16: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S32: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U32: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S64: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U64: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case PRED: { std::uniform_int_distribution generator(0, 1); TF_CHECK_OK( - literal->Populate([&](absl::Span /*indices*/) { + literal.Populate([&](absl::Span /*indices*/) { return generator(*engine); })); break; @@ -236,8 +229,8 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -std::unique_ptr MakeRandomIndex(absl::Span index_space, - std::minstd_rand0* engine) { +Literal MakeRandomIndex(absl::Span index_space, + std::minstd_rand0* engine) { std::vector start_indices(index_space.size()); if (engine != nullptr) { for (int i = 0; i < index_space.size(); ++i) { @@ -293,7 +286,7 @@ std::vector FindConstrainedUses( // no constrained uses in the dataflow graph. If such constraints exist, // generate a constrained literal (either bounded in the case of indices, or // zero in the case of init_values for reductions). -StatusOr> CreateLiteralForConstrainedUses( +StatusOr CreateLiteralForConstrainedUses( const absl::Span constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { std::vector index_space; @@ -358,9 +351,9 @@ StatusOr> CreateLiteralForConstrainedUses( } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: - return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::Zero(param.shape().element_type()); case ConstantType::kOne: - return LiteralUtil::One(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::One(param.shape().element_type()); case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. @@ -374,34 +367,33 @@ StatusOr> CreateLiteralForConstrainedUses( // Given a module entry parameter, use the dataflow analysis to see if a // special case literal must be created, or if we can generate fake data. -StatusOr> MakeConstrainedArgument( - const HloDataflowAnalysis& dataflow, const HloInstruction& param, - std::minstd_rand0* engine) { +StatusOr MakeConstrainedArgument(const HloDataflowAnalysis& dataflow, + const HloInstruction& param, + std::minstd_rand0* engine) { const auto constrained_uses = FindConstrainedUses(dataflow, param); return CreateLiteralForConstrainedUses(constrained_uses, param, engine); } } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape, - bool pseudo_random) { +StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random) { auto engine = pseudo_random ? absl::make_unique() : nullptr; return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false); } -StatusOr>> MakeFakeArguments( - HloModule* const module, bool pseudo_random) { +StatusOr> MakeFakeArguments(HloModule* const module, + bool pseudo_random) { auto engine = pseudo_random ? absl::make_unique() : nullptr; return MakeFakeArguments(module, engine.get()); } -StatusOr>> MakeFakeArguments( - HloModule* const module, std::minstd_rand0* engine) { +StatusOr> MakeFakeArguments(HloModule* const module, + std::minstd_rand0* engine) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::vector> arguments(params.size()); + std::vector arguments(params.size()); for (int i = 0; i < params.size(); ++i) { arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index a260271b1bc344fb44fedcca8a69b0b61c82c8e7..b3c8a739058475a4e51bae6ad2a98152a6532b47 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -57,8 +57,8 @@ class PseudorandomGenerator { // Generates fake data in a literal of the given shape, or returns an error // status if the element type is currently unhandled for fake data // generation. See below for documentation of pseudo_random. -StatusOr> MakeFakeLiteral(const Shape& shape, - bool pseudo_random = true); +StatusOr MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. @@ -84,14 +84,14 @@ StatusOr> MakeFakeLiteral(const Shape& shape, // TODO(b/79942829): Make interesting argument generation fast enough that using // pseudo_random does not save any noticeable amount of time so that the // parameter can be removed. -StatusOr>> MakeFakeArguments( - HloModule* const module, bool pseudo_random = true); +StatusOr> MakeFakeArguments(HloModule* const module, + bool pseudo_random = true); // Overload which accepts a random number generator. This enables generation of // different random values with sequential calls to MakeFakeArguments by reusing // the same generator. -StatusOr>> MakeFakeArguments( - HloModule* const module, std::minstd_rand0* engine); +StatusOr> MakeFakeArguments(HloModule* const module, + std::minstd_rand0* engine); // Check that a given module satisfies various constraints before trying to // execute it. diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 322c8ef090cf867f65cada5cb1dbae188f83bad6..181e5cbe290b0df0cf605cc4ef4b8a4945b3d367 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -85,10 +85,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} })") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 3); - const Literal& index_arg = *args[0]; + const Literal& index_arg = args[0]; EXPECT_EQ(index_arg.Get({0}), 0); @@ -114,10 +114,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) })") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 5); - const Literal& index_arg = *args[0]; + const Literal& index_arg = args[0]; EXPECT_EQ(index_arg.Get({0}), 0); @@ -140,10 +140,10 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> ( } )") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 2); - const Literal& key_arg = *args[0]; + const Literal& key_arg = args[0]; tensorflow::gtl::FlatSet key_set; for (const float& value : key_arg.data()) { @@ -163,10 +163,10 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> ( } )") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 2); - const Literal& key_arg = *args[0]; + const Literal& key_arg = args[0]; tensorflow::gtl::FlatSet key_set; for (const int32& value : key_arg.data()) { diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index c7eb9e2dbe0e27b7933f5861280a3401cd268c08..b34fd0f2e873214c509533f29553af914ddc984d 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -34,9 +34,8 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, TokenTree) { @@ -50,9 +49,8 @@ XLA_TEST_F(TokenHloTest, TokenTree) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { @@ -193,9 +191,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); auto arg = LiteralUtil::CreateR0(true); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {arg.get()})); - EXPECT_EQ(42, result->Get({})); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg})); + EXPECT_EQ(42, result.Get({})); } { @@ -204,9 +201,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); auto arg = LiteralUtil::CreateR0(false); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {arg.get()})); - EXPECT_EQ(7, result->Get({})); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg})); + EXPECT_EQ(7, result.Get({})); } } diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 125513ddfd16cb4e742e7d589e22b721307621ee..d6641d257a75945be94d299a1bd4b0366e3759b7 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -69,90 +69,90 @@ class TransferManagerTest : public LocalClientTestBase { }; XLA_TEST_F(TransferManagerTest, TransferR0U32) { - std::unique_ptr literal = LiteralUtil::CreateR0(42); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR0(42); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - LiteralTestUtil::ExpectR0Equal(42, *result); + LiteralTestUtil::ExpectR0Equal(42, result); } XLA_TEST_F(TransferManagerTest, TransferR1F32) { - std::unique_ptr literal = + Literal literal = LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, - *result); + result); } XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { std::vector test_vector(1024 * 1024); std::iota(test_vector.begin(), test_vector.end(), 0); - std::unique_ptr literal = LiteralUtil::CreateR1(test_vector); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR1(test_vector); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - LiteralTestUtil::ExpectR1Equal(test_vector, *result); + LiteralTestUtil::ExpectR1Equal(test_vector, result); } XLA_TEST_F(TransferManagerTest, TransferR1U8) { const char* test_string = "0123456789abcdef"; - std::unique_ptr literal = LiteralUtil::CreateR1U8(test_string); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR1U8(test_string); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_EQ(result->GetR1U8AsString(), test_string); + EXPECT_EQ(result.GetR1U8AsString(), test_string); } XLA_TEST_F(TransferManagerTest, TransferR2F32) { - std::unique_ptr literal = + Literal literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result); } XLA_TEST_F(TransferManagerTest, TransferR2F32AndChangeLayoutTransferringToDevice) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1})); const Shape ondevice_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -160,101 +160,99 @@ XLA_TEST_F(TransferManagerTest, // Round trip literal through device. Set the on-device layout to something // different than the literal layout. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_FALSE( - LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); + LayoutUtil::Equal(result.shape().layout(), literal.shape().layout())); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result); } XLA_TEST_F(TransferManagerTest, TransferTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple({}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTuple({}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-10.0f, 123.0f})}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { - std::unique_ptr literal = LiteralUtil::CreateR1( + Literal literal = LiteralUtil::CreateR1( {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( + Literal literal = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR1( - {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) - .get(), - LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}).get(), - LiteralUtil::CreateR0(complex64(0.3f, -0.4f)).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}), + LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}), + LiteralUtil::CreateR0(complex64(0.3f, -0.4f))}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { @@ -264,54 +262,52 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { // supported. auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape()); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result)); } XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { const int64 kIterationCount = 5000; - std::unique_ptr literal1 = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); - std::unique_ptr literal2 = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(456.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-98.0f, 153.0f}).get()}); - - auto device_buffer1 = AllocateDeviceBuffer(literal1->shape()); - auto device_buffer2 = AllocateDeviceBuffer(literal2->shape()); + Literal literal1 = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-10.0f, 123.0f})}); + Literal literal2 = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(456.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}), + LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-98.0f, 153.0f})}); + + auto device_buffer1 = AllocateDeviceBuffer(literal1.shape()); + auto device_buffer2 = AllocateDeviceBuffer(literal2.shape()); auto stream1 = stream_; auto stream2 = stream_->GetOrCreateSubStream(); - std::unique_ptr result1, result2; + Literal result1, result2; // Round trip literals through device in multiple streams asynchronously. for (int i = 0; i < kIterationCount; ++i) { - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1, device_buffer1)); - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2, device_buffer2)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr this_result1, + Literal this_result1, transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr this_result2, + Literal this_result2, transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2)); result1 = std::move(this_result1); result2 = std::move(this_result2); } - EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2)); } class TransferDeviceToHostBenchmark : public TransferManagerTest { @@ -323,20 +319,19 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest { tensorflow::testing::StopTiming(); SetUp(); - std::vector> tuple_elements; + std::vector tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } - std::unique_ptr literal = - LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); - TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); } tensorflow::testing::StopTiming(); @@ -355,17 +350,16 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest { tensorflow::testing::StopTiming(); SetUp(); - std::vector> tuple_elements; + std::vector tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } - std::unique_ptr literal = - LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); } tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index f2b3b49015c7d74d786f63776abff1d5181fd961..619d2a388b5646c31f0a61f709a2ab3067e39c03 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -51,13 +51,13 @@ XLA_TEST_F(TupleTest, TupleConstant) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector), + LiteralUtil::CreateR2(constant_matrix)}); - ConstantLiteral(&builder, *value); - ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests a tuple made of scalar constants. @@ -66,12 +66,12 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar1).get(), - LiteralUtil::CreateR0(constant_scalar2).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar1), + LiteralUtil::CreateR0(constant_scalar2)}); - ConstantLiteral(&builder, *value); - ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests the creation of tuple data. @@ -88,11 +88,11 @@ XLA_TEST_F(TupleTest, TupleCreate) { ConstantR1(&builder, constant_vector), ConstantR2(&builder, constant_matrix)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector), + LiteralUtil::CreateR2(constant_matrix)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of tuple data. @@ -102,10 +102,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { Tuple(&builder, {ConstantR0(&builder, 7.0), ConstantR1(&builder, {})}); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), - LiteralUtil::CreateR1({}).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(7.0), LiteralUtil::CreateR1({})}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of an empty tuple. @@ -113,7 +112,7 @@ XLA_TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); Tuple(&builder, {}); auto expected = LiteralUtil::MakeTuple({}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. @@ -196,10 +195,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ConstantR2(&builder, constant_matrix)}); Tuple(&builder, {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2(constant_matrix).get(), - LiteralUtil::CreateR1(constant_vector).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2(constant_matrix), + LiteralUtil::CreateR1(constant_vector)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { @@ -218,11 +217,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true} auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false} Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(direction).get(), - LiteralUtil::CreateR0(!direction).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(direction), + LiteralUtil::CreateR0(!direction)}); - ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, + ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()}, error_spec_); } } @@ -287,10 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec2), LiteralUtil::CreateR1(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, TuplesInAMap) { @@ -332,10 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, true), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), - LiteralUtil::CreateR1(vec2).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec1), LiteralUtil::CreateR1(vec2)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { @@ -408,10 +405,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec2), LiteralUtil::CreateR1(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, NestedTuples) { @@ -423,12 +419,11 @@ XLA_TEST_F(TupleTest, NestedTuples) { auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); auto expected_s = LiteralUtil::CreateR0(42.0); auto expected_inner_tuple = - LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); + LiteralUtil::MakeTuple({&expected_v1, &expected_s}); auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); - auto expected = - LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { @@ -446,14 +441,12 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::MakeTuple( - { - LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), - LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), - }) - .get(), - LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1({1.0, 2.0, 3.0}), + LiteralUtil::CreateR1({4.0, 5.0, 6.0}), + }), + LiteralUtil::CreateR1({7.0, 8.0, 9.0}), })) .ConsumeValueOrDie(); @@ -484,40 +477,36 @@ XLA_TEST_F(TupleTest, ComplexTuples) { std::unique_ptr arg0 = client_ - ->TransferToServer(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0({1, 2}).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({{10, 20}, {30, 40}}) - .get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0({1, 2}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({{10, 20}, {30, 40}}), LiteralUtil::CreateR2( {{{100, 200}, {300, 400}}, {{1000, 2000}, {3000, 4000}}, - {{10000, 20000}, {30000, 40000}}}) - .get()}) - .get()})) + {{10000, 20000}, {30000, 40000}}})})})) .ConsumeValueOrDie(); std::unique_ptr arg1 = client_ ->TransferToServer( - *LiteralUtil::CreateR1({{1, 2}, {1, -2}})) + LiteralUtil::CreateR1({{1, 2}, {1, -2}})) .ConsumeValueOrDie(); auto sum = LiteralUtil::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = absl::make_unique(sum->shape()); - ASSERT_TRUE(prod->Populate( - [&sum](absl::Span indexes) { - return sum->Get(indexes) * - (indexes[indexes.size() - 1] == 0 - ? complex64(1, 2) - : complex64(1, -2)); - }) + Literal prod(sum.shape()); + ASSERT_TRUE(prod.Populate([&sum](absl::Span indexes) { + return sum.Get(indexes) * + (indexes[indexes.size() - 1] == 0 + ? complex64(1, 2) + : complex64(1, -2)); + }) .ok()); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(), - LiteralUtil::CreateR0({123, 456}).get()}); - ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices({prod, sum}), + LiteralUtil::CreateR0({123, 456})}); + ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()}, error_spec_); } @@ -541,10 +530,10 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { .ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); - auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); + auto result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), - *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), + result)); } // Disabled on interpreter due to lack of outfeed. @@ -581,16 +570,15 @@ XLA_TEST_F(TupleHloTest, tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { TF_EXPECT_OK(Execute(std::move(module), - {param0.get(), param1.get(), param1.get(), - param0.get(), param4.get()}) + {¶m0, ¶m1, ¶m1, ¶m0, ¶m4}) .status()); })); auto expected = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({2, 3})); - auto literal = Literal::CreateFromShape(expected->shape()); + auto literal = Literal::CreateFromShape(expected.shape()); TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - backend().default_stream_executor(), expected->shape(), *literal)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal)); + backend().default_stream_executor(), expected.shape(), literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 8f80a9f3e466d73f2b718452d9a0d64a80c3b36f..4fbd7f2fb174ac899c1e3b23801986cb52db96a2 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -100,9 +100,9 @@ void UnaryOpTest::AbsTestHelper() { {-inf(), 0}}); Abs(arg); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR1({2, 25, 0, 0.5, inf(), inf()}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> @@ -113,9 +113,9 @@ void UnaryOpTest::SignTestHelper() { {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); Sign(arg); - std::unique_ptr expected = LiteralUtil::CreateR1( + Literal expected = LiteralUtil::CreateR1( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> @@ -127,9 +127,8 @@ void UnaryOpTest::SignAbsTestHelper() { auto abs = Abs(arg); Sub(Mul(sign, ConvertElementType(abs, C64)), arg); - std::unique_ptr expected = - LiteralUtil::CreateR1({0, 0, 0, 0}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + Literal expected = LiteralUtil::CreateR1({0, 0, 0, 0}); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) { @@ -172,9 +171,8 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { Add(sgnc, ConvertElementType( Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64)); - std::unique_ptr expected = - LiteralUtil::CreateR0({-2.6f, 0.8f}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + Literal expected = LiteralUtil::CreateR0({-2.6f, 0.8f}); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, SignTestR1) { diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 1bdf1867b9330b715b0ba4aca71d56307883c775..7abd8651d5ca272f9e82d797870a3bd6b1589615 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -348,9 +348,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // have all reached 2.0. auto expected_data = LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); - auto expected = LiteralUtil::MakeTuple({expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { @@ -401,11 +401,10 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { auto expected_w1 = LiteralUtil::CreateR1({1.0f, 1.0f, 1.0f}); auto expected_w2 = LiteralUtil::CreateR1({2.0f, 2.0f, 2.0f}); auto expected_w3 = LiteralUtil::CreateR1({3.0f, 3.0f, 3.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(), - expected_w3.get(), expected_w1.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple( + {&expected_counter, &expected_w2, &expected_w3, &expected_w1}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { @@ -510,10 +509,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPredicateTupleResult) { @@ -557,9 +555,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_predicate = LiteralUtil::CreateR0(true); - auto expected = LiteralUtil::MakeTuple( - {expected_counter.get(), expected_predicate.get()}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); + auto expected = + LiteralUtil::MakeTuple({&expected_counter, &expected_predicate}); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0)); } TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { @@ -602,10 +600,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR0(7); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } // Tests two while nodes when the result type T is a Tuple and the second @@ -886,10 +883,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR1( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } // Tests a while node when the result type T is a vector of S32. @@ -977,11 +973,11 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto expected_element = LiteralUtil::CreateR1({1, 1}); auto expected = - LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()}); + LiteralUtil::MakeTuple({&expected_element, &expected_element}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); - ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + client_->TransferToServer(LiteralUtil::CreateR1({42, 42}))); + ComputeAndCompareTuple(&outer, expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1005,7 +1001,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); + client_->TransferToServer(LiteralUtil::CreateR1({42, 42}))); ComputeAndCompareR1(&outer, {1.0f, 1.0f}, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1031,7 +1027,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR0(42))); + client_->TransferToServer(LiteralUtil::CreateR0(42))); ComputeAndCompareR0(&outer, 43.0f, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1070,12 +1066,12 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR0(1))); + client_->TransferToServer(LiteralUtil::CreateR0(1))); auto add1 = LiteralUtil::CreateR0(15); auto add2 = LiteralUtil::CreateR0(16); - auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()}); - ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + auto expected = LiteralUtil::MakeTuple({&add1, &add2}); + ComputeAndCompareTuple(&outer, expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1228,7 +1224,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { GetTupleElement(while_instruction, 3); TF_ASSERT_OK_AND_ASSIGN( - auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2( + auto param_value, client_->TransferToServer(LiteralUtil::CreateR2( {{1.0, 2.0}, {-1.0, -2.0}}))); ComputeAndCompareR2( @@ -1258,9 +1254,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { XlaBuilder builder(TestName()); While(condition, body, ConstantR0(&builder, 0)); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(false))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(false))); ComputeAndCompareR0(&builder, 2, {}); } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 7fd42944debe38abbf6f0ca36bc5c7ecb1aeaf97..db5a824de08edeb81b5deb047507dc6158833008 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -144,14 +144,14 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, transfer_manager->AllocateScopedShapedBuffer( lhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); + stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer rhs_arg, transfer_manager->AllocateScopedShapedBuffer( rhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 442e66321ee732f3d9cdfe4931433bd864b7fa82..cdde88c1359416d423685f330e9cbdf77948034f 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -39,8 +39,7 @@ limitations under the License. namespace xla { -StatusOr> TextLiteralReader::ReadPath( - absl::string_view path) { +StatusOr TextLiteralReader::ReadPath(absl::string_view path) { CHECK(!absl::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; @@ -57,7 +56,7 @@ StatusOr> TextLiteralReader::ReadPath( TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file) : file_(file) {} -StatusOr> TextLiteralReader::ReadAllLines() { +StatusOr TextLiteralReader::ReadAllLines() { tensorflow::io::RandomAccessInputStream stream(file_.get()); tensorflow::io::BufferedInputStream buf(&stream, 65536); string shape_string; @@ -74,9 +73,9 @@ StatusOr> TextLiteralReader::ReadAllLines() { ShapeUtil::HumanString(shape)); } - auto result = absl::make_unique(shape); + Literal result(shape); const float fill = std::numeric_limits::quiet_NaN(); - result->PopulateWithValue(fill); + result.PopulateWithValue(fill); std::vector pieces; std::vector coordinates; std::vector coordinate_values; @@ -116,7 +115,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { "\"%s\"", shape.dimensions_size(), coordinate_values.size(), line); } - result->Set(coordinate_values, value); + result.Set(coordinate_values, value); } return std::move(result); } diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index b265640802c88847ce57e9f942f9f0859b873ae8..c40b43279f56fbd6e8ec91cc45c1f8e4cac8b5ef 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -41,7 +41,7 @@ class TextLiteralReader { public: // See class comment -- reads a file in its entirety (there must be only one // literal in the text file path provided). - static StatusOr> ReadPath(absl::string_view path); + static StatusOr ReadPath(absl::string_view path); private: // Ownership of file is transferred. @@ -49,7 +49,7 @@ class TextLiteralReader { // Parses a shape string on the first line, followed by lines of values to the // end of the file. - StatusOr> ReadAllLines(); + StatusOr ReadAllLines(); // Owns the file being read std::unique_ptr file_; diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index 92f9b4f9f0efa2dc08287bdcbefc88f879164308..1fab4e3a08dd3d76a6efeaabe7bf8ab96892e638 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -42,16 +42,15 @@ TEST(TextLiteralReaderTest, ReadsR3File) { tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents) .ok()); - std::unique_ptr literal = - TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); + Literal literal = TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); EXPECT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape())); - EXPECT_EQ(42.5, literal->Get({0, 0, 0})); - EXPECT_EQ(43.5, literal->Get({0, 0, 1})); - EXPECT_EQ(44.5, literal->Get({0, 0, 2})); - EXPECT_EQ(45.5, literal->Get({0, 1, 0})); - EXPECT_EQ(46.5, literal->Get({0, 1, 1})); - EXPECT_EQ(47.5, literal->Get({0, 1, 2})); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal.shape())); + EXPECT_EQ(42.5, literal.Get({0, 0, 0})); + EXPECT_EQ(43.5, literal.Get({0, 0, 1})); + EXPECT_EQ(44.5, literal.Get({0, 0, 2})); + EXPECT_EQ(45.5, literal.Get({0, 1, 0})); + EXPECT_EQ(46.5, literal.Get({0, 1, 1})); + EXPECT_EQ(47.5, literal.Get({0, 1, 2})); } } // namespace diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 4ea02faffcd52065b05c0444202bd1a3d9d87ee6..5cbaf2fcc192c48092272094710ccaf5c9cf9616 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -37,7 +37,7 @@ TEST(TextLiteralWriterTest, WritesFloatLiteral) { }); string path = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever"); - ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path)); + ASSERT_IS_OK(TextLiteralWriter::WriteToPath(literal, path)); string contents; TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path, &contents)); diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ba814af4769f43dbe96190c902cf6f52ca5659bb..0c41f227b31ebe1f01073785ea2a666093aefdb3 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -121,11 +121,10 @@ StatusOr ReplayComputation(const HloSnapshot& module, } } else { // use recorded data if available for (const auto& proto : module.arguments()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - Literal::CreateFromProto(proto)); + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto)); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer data, - client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0)); + client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0)); scoped_shaped_buffer_arguments.push_back(std::move(data)); } for (const auto& argument : scoped_shaped_buffer_arguments) { @@ -161,12 +160,12 @@ StatusOr ReplayComputation(const HloSnapshot& module, // --generate_fake_infeed is passed and there exists an infeed operation in // the HloSnapshot. absl::optional pool; - std::unique_ptr data; + Literal data; if (provide_infeed) { data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); } auto transfer_infeed = [&data, client]() { - TF_CHECK_OK(client->TransferToInfeed(*data)); + TF_CHECK_OK(client->TransferToInfeed(data)); }; if (provide_infeed) { pool.emplace(tensorflow::Env::Default(), "infeed", @@ -214,9 +213,9 @@ StatusOr ReplayComputation(const HloSnapshot& module, << "s: " << module.hlo().hlo_module().name(); } - TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, + TF_ASSIGN_OR_RETURN(Literal result_literal, client->ShapedBufferToLiteral(*result)); - return std::move(*result_literal); + return result_literal; } StatusOr ParseInputFile(const string& filename, @@ -305,11 +304,11 @@ int RealMain(absl::Span args, const Options& opts) { result.ToString().c_str()); auto& snapshot = snapshots[i]; if (snapshot.has_result()) { - std::unique_ptr literal = + Literal literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(snapshot.result().shape()).c_str(), - literal->ToString().c_str()); + literal.ToString().c_str()); } } } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index 51909190a3ef20c3df78d08796e88bdbb650609d..4f8852f8c11fb749ef851bc4faf176fcc5cb3524 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -40,8 +40,8 @@ int main(int argc, char **argv) { xla::LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], &literal_proto)); - std::unique_ptr literal = + xla::Literal literal = xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); - fprintf(stderr, "%s\n", literal->ToString().c_str()); + fprintf(stderr, "%s\n", literal.ToString().c_str()); } diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index 48c837481181f6ad8f864569fd62e0e23fa02ecd..4b5c276bdf66f3dc5364aae4654b13a625b0a4f7 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -36,16 +36,16 @@ int main(int argc, char **argv) { LOG(QFATAL) << "Usage: " << argv[0] << " "; } - std::unique_ptr literal = + xla::Literal literal = xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie(); - LOG(INFO) << "literal: " << *literal; - fprintf(stderr, "%s\n", literal->ToString().c_str()); - if (literal->shape().element_type() == xla::F32) { - float min = *std::min_element(literal->data().begin(), - literal->data().end()); - float max = *std::max_element(literal->data().begin(), - literal->data().end()); + LOG(INFO) << "literal: " << literal; + fprintf(stderr, "%s\n", literal.ToString().c_str()); + if (literal.shape().element_type() == xla::F32) { + float min = *std::min_element(literal.data().begin(), + literal.data().end()); + float max = *std::max_element(literal.data().begin(), + literal.data().end()); fprintf(stderr, "min: %a=%f\n", min, min); fprintf(stderr, "max: %a=%f\n", max, max); } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index dd329f118172d26be3c900958f94b55b7fd6691e..73b3589dbf12341ddb3f3e819a550467a7b4d166 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -351,6 +351,7 @@ message DeviceAssignmentProto { message LiteralProto { Shape shape = 1; repeated bool preds = 2; + bytes s8s = 15; bytes u8s = 3; repeated int32 s32s = 4; repeated int64 s64s = 5; @@ -364,7 +365,7 @@ message LiteralProto { bytes f16s = 11; bytes bf16s = 13; repeated int64 sparse_indices = 14; - // Next = 15 + // Next = 16 } message WindowDimension { diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 478c9663a7641ba2bf22e9119212ee8ef8947d4f..54b06558adcd8ef1f8f1bee52d210d558801afea 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -49,7 +49,7 @@ class XRTStateHelpers { // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an // OpKernel::Compute method. static Status MakeLiteral(const xla::LiteralProto& proto, - std::unique_ptr* literal) { + xla::Literal* literal) { TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto)); return Status::OK(); } @@ -173,7 +173,7 @@ class XRTAllocateOp : public OpKernel { errors::InvalidArgument( "Unable to parse allocation input to XLAAllocation")); - std::unique_ptr literal; + xla::Literal literal; OP_REQUIRES_OK( ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal)); @@ -189,7 +189,7 @@ class XRTAllocateOp : public OpKernel { XRTTupleAllocation* allocation; OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( - *literal, device_ref.backend(), + literal, device_ref.backend(), device_ref.device_ordinal(), &allocation)); // Intern takes ownership of our reference to allocation. @@ -381,11 +381,11 @@ class XRTReadLiteralOp : public OpKernel { OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( ctx, allocation->device_ordinal(), &device_ref)); - std::unique_ptr literal; + xla::Literal literal; OP_REQUIRES_OK( ctx, allocation->ToLiteral(device_ref.backend(), device_ref.device_ordinal(), &literal)); - xla::LiteralProto literal_proto = literal->ToProto(); + xla::LiteralProto literal_proto = literal.ToProto(); Tensor output(DT_STRING, TensorShape({})); literal_proto.SerializeToString(&output.scalar()()); diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 5b8516bf1dceb4ffa37a8fb52fb287281a661e9d..2952feb16a8a60aecf16be87c9b800d314c4af58 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -52,44 +52,44 @@ string DeviceFromFlag() { xla::LiteralProto TwoElementTuple() { auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()}); - return tuple->ToProto(); + auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); + return tuple.ToProto(); } xla::LiteralProto ScalarLiteral() { auto scalar = xla::LiteralUtil::CreateR0(12.0f); - return scalar->ToProto(); + return scalar.ToProto(); } xla::LiteralProto NestedTuple() { auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()}); + auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); auto scalar = xla::LiteralUtil::CreateR0(12.0f); - auto nested = xla::LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); - return nested->ToProto(); + auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar}); + return nested.ToProto(); } xla::LiteralProto MakeTuple0() { auto scalar = xla::LiteralUtil::CreateR0(12.0f); auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()}); - auto nested0 = xla::LiteralUtil::MakeTuple({scalar.get(), tuple.get()}); - auto nested1 = xla::LiteralUtil::MakeTuple({scalar.get(), nested0.get()}); - return nested1->ToProto(); + auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); + auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple}); + auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0}); + return nested1.ToProto(); } -xla::LiteralProto FloatVector(gtl::ArraySlice v) { +xla::LiteralProto FloatVector(absl::Span v) { auto array = xla::LiteralUtil::CreateR1(v); - return array->ToProto(); + return array.ToProto(); } bool CompareLiteralProtos(const xla::LiteralProto& a, const xla::LiteralProto& b) { auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie(); auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); - bool equal = *l_a == *l_b; + bool equal = l_a == l_b; if (!equal) { LOG(INFO) << "LiteralProtos don't match " << a.DebugString() << " != " << b.DebugString(); @@ -100,7 +100,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a, bool CompareLiteralToLiteralProto(const xla::Literal& a, const xla::LiteralProto& b) { auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); - bool equal = a == *l_b; + bool equal = a == l_b; if (!equal) { LOG(INFO) << "Literal and LiteralProto don't match " << a.ToProto().DebugString() << " != " << b.DebugString(); @@ -211,7 +211,7 @@ TEST(RawApiTest, SubBuffer) { TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs)); auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie(); - auto base_elements = base_literal->DecomposeTuple(); + auto base_elements = base_literal.DecomposeTuple(); auto nested_0_elements = base_elements[0].Clone().DecomposeTuple(); xla::LiteralProto response_0; EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar()())); @@ -343,7 +343,7 @@ TEST(RawApiTest, CompileAndExecute) { EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); - EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response)); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } TEST(RawApiTest, CompileAndExecuteReturnTuple) { @@ -392,8 +392,8 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto sum = xla::LiteralUtil::CreateR1({9.0f, 7.0f}); - auto expected = xla::LiteralUtil::MakeTuple({sum.get()}); - EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response)); + auto expected = xla::LiteralUtil::MakeTuple({&sum}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } } // namespace diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 2c3b07da58b5d59632bf813256f9a14fa5c06413..d05a1e7dcbff440e0daf03bd25535c26d82b6a0b 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -174,7 +174,7 @@ XRTTupleAllocation::~XRTTupleAllocation() { } Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, - std::unique_ptr* literal) { + xla::Literal* literal) { auto transfer_manager = backend->transfer_manager(); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice( diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 42705688ddfeb21aa734cccfce36c8d11d0d60a9..73b5584e38f781343fe6793af7ad28232fbfc184 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -135,7 +135,7 @@ class XRTTupleAllocation : public ResourceBase { // Copies the allocation from device to host and returns it in literal. Status ToLiteral(xla::Backend* backend, int device_ordinal, - std::unique_ptr* literal); + xla::Literal* literal); // True if none of the buffers in the allocation are aliased by any other live // handle. diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 798f499870095043b77389d0f39306bd4d309259..d98a24994cbf080184fe47111a718f31b7a64f0b 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -166,7 +166,9 @@ cc_library( "//tensorflow/contrib/kinesis:dataset_kernels", ], "//conditions:default": [], - }), + }) + if_not_windows([ + "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", + ]), ) cc_library( @@ -203,5 +205,7 @@ cc_library( "//tensorflow/contrib/kinesis:dataset_ops_op_lib", ], "//conditions:default": [], - }), + }) + if_not_windows([ + "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", + ]), ) diff --git a/tensorflow/contrib/autograph/BUILD b/tensorflow/contrib/autograph/BUILD index ad700ac4a0342e2a7bc07a6ecf6710cea892e296..e37ad7a7581666e8207d5d35e197be3b3576a24d 100644 --- a/tensorflow/contrib/autograph/BUILD +++ b/tensorflow/contrib/autograph/BUILD @@ -21,11 +21,9 @@ py_library( ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], + # This module is kept for backward compatibility only. To depend on AutoGraph, + # use //third_party/tensorflow/python/autograph instead. deps = [ - "//tensorflow/contrib/autograph/impl", - "//tensorflow/contrib/autograph/lang", - "//tensorflow/contrib/autograph/pyct", - "//tensorflow/contrib/autograph/utils", - "//tensorflow/python:util", + "//tensorflow/python/autograph", ], ) diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index cc54da4daa9a5bb4e64145963ffec63021d08876..6ea2db72c411f2f19a06ff108d6b63fc3bde352b 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,5 +1,12 @@ # AutoGraph +**NOTE: As tensorflow.contrib is being +[deprecated](https://github.com/tensorflow/community/pull/18), AutoGraph is +moving into TensorFlow core. + +The new code location is `tensorflow/python/autograph`. +** + IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). AutoGraph is a Python to TensorFlow compiler. diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index 26e7a4a4d38e264486c981e6fc4c547bcc53b302..137bc59202b26c1c224fec4c2fca2dec83db13a5 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -12,57 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Autograph compiles Python code into equivalent TensorFlow code. +"""This is the legacy module for AutoGraph, kept for backward compatibility. -Equivalent here means that they have the same effect when executed. +New users should instead use `tensorflow.python.autograph`. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# TODO(mdan): Bring only the relevant symbols to the top level. -from tensorflow.contrib.autograph import operators -from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.core.errors import GraphConstructionError -from tensorflow.contrib.autograph.core.errors import TfRuntimeError -from tensorflow.contrib.autograph.core.errors import improved_errors -from tensorflow.contrib.autograph.impl.api import RunMode -from tensorflow.contrib.autograph.impl.api import convert -from tensorflow.contrib.autograph.impl.api import converted_call -from tensorflow.contrib.autograph.impl.api import do_not_convert -from tensorflow.contrib.autograph.impl.api import to_code -from tensorflow.contrib.autograph.impl.api import to_graph -from tensorflow.contrib.autograph.lang.directives import set_element_type -from tensorflow.contrib.autograph.lang.directives import set_loop_options -from tensorflow.contrib.autograph.lang.special_functions import stack -from tensorflow.contrib.autograph.lang.special_functions import tensor_list -from tensorflow.contrib.autograph.pyct.transformer import AutographParseError -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - # Main API - 'RunMode', - 'convert', - 'converted_call', - 'do_not_convert', - 'to_code', - 'to_graph', - # Overloaded operators - 'operators', - # Errors - 'improved_errors', - 'GraphConstructionError', - 'TfRuntimeError', - # Python language "extensions" - 'set_element_type', - 'set_loop_options', - 'stack', - 'tensor_list', - # Exceptions - 'AutographParseError', - # Utilities: to be removed - 'utils', -] - -remove_undocumented(__name__, _allowed_symbols) +from tensorflow.python.autograph import * # pylint:disable=wildcard-import diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 1375fddf2bea1a8f856c35d756c38a8beb14a53f..606da663dc2e43688bc42bf6e33a48cd680f54e1 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -296,8 +296,9 @@ class QuantileAccumulatorAddSummariesOp : public OpKernel { int64 start, int64 end) { for (int resource_handle_idx = start; resource_handle_idx < end; ++resource_handle_idx) { - ResourceHandle handle = resource_handle_list[resource_handle_idx] - .flat()(0); + const ResourceHandle& handle = + resource_handle_list[resource_handle_idx] + .flat()(0); QuantileStreamResource* streams_resource; // Create a reference to the underlying resource using the handle. OP_REQUIRES_OK(context, @@ -709,8 +710,9 @@ class QuantileAccumulatorGetBucketsOp : public OpKernel { &buckets_list, stamp_token](int64 start, int64 end) { for (int resource_handle_idx = start; resource_handle_idx < end; ++resource_handle_idx) { - ResourceHandle handle = resource_handle_list[resource_handle_idx] - .flat()(0); + const ResourceHandle& handle = + resource_handle_list[resource_handle_idx] + .flat()(0); QuantileStreamResource* streams_resource; OP_REQUIRES_OK(context, LookupResource(context, handle, &streams_resource)); diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 3b28ed77f325b3f8b09fe6b9d2776eff82ff53a7..51e0c2e431acbea727bc0b2149557d0e30c8c432 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -862,6 +862,15 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); equality_split->set_feature_column(state->feature_column_group_id()); + CHECK(feature_ids(best_feature_idx, 0) != bias_feature_id) + << "Unexpected feature ID selected. " + << "Start feature ID: [" << start_index << "] " + << feature_ids(start_index, 0) << ", " << feature_ids(start_index, 1) + << "\nBest feature ID: [" << best_feature_idx << "] " + << feature_ids(best_feature_idx, 0) << ", " + << feature_ids(best_feature_idx, 1) + << "\nPartition IDS: " << partition_ids(start_index) << " " + << partition_ids(best_feature_idx); equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index 90a0655201f8cb8df6fc6417cb51216dec91b4d7..e446c411a8d5075563b8f8b912b29df310e16c8c 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -448,8 +448,9 @@ class StatsAccumulatorScalarAddOp : public OpKernel { stamp_token](int64 start, int64 end) { for (int resource_handle_idx = start; resource_handle_idx < end; ++resource_handle_idx) { - ResourceHandle handle = resource_handle_list[resource_handle_idx] - .flat()(0); + const ResourceHandle& handle = + resource_handle_list[resource_handle_idx] + .flat()(0); StatsAccumulatorScalarResource* accumulator_resource; OP_REQUIRES_OK(context, LookupResource(context, handle, @@ -512,8 +513,9 @@ class StatsAccumulatorTensorAddOp : public OpKernel { stamp_token](int64 start, int64 end) { for (int resource_handle_idx = start; resource_handle_idx < end; ++resource_handle_idx) { - ResourceHandle handle = resource_handle_list[resource_handle_idx] - .flat()(0); + const ResourceHandle& handle = + resource_handle_list[resource_handle_idx] + .flat()(0); StatsAccumulatorTensorResource* accumulator_resource; OP_REQUIRES_OK(context, LookupResource(context, handle, diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index 35d727482bf631f2fe14e02c1ec4b75a763e8615..4da25298cb82093ac501997cc21c48265df06860 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -_BIAS_FEATURE_ID = -1 +_BIAS_FEATURE_ID = int(dtypes.int64.min) class EqualitySplitHandler(base_split_handler.BaseSplitHandler): diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index d9f03c3840f8edd88174be4e97aaaf7d0efd220b..94ea7bc2eb7b098a0628683167510bf4e3c2426e 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -47,7 +47,7 @@ def get_empty_tensors(gradient_shape, hessian_shape): class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Feature ID | # i0 | (0.2, 0.12) | 0 | 1,2 | @@ -281,7 +281,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): gains[0], 0.00001) def testGenerateFeatureSplitCandidatesSumReduction(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Feature ID | # i0 | (0.2, 0.12) | 0 | 1,2 | @@ -404,7 +404,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) def testGenerateFeatureSplitCandidatesMulticlass(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Batch size is 4, 2 gradients per each instance. gradients = array_ops.constant( [[0.2, 0.1], [-0.5, 0.2], [1.2, 3.4], [4.0, -3.5]], shape=[4, 2]) @@ -482,7 +482,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) def testEmpty(self): - with self.test_session() as sess: + with self.cached_session() as sess: gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = [0, 0, 0, 1] @@ -530,7 +530,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testInactive(self): - with self.test_session() as sess: + with self.cached_session() as sess: gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = [0, 0, 0, 1] diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 5532bd026ab695d166bc2e2872ecc551920978d5..74b0ea6989c65e83e7a466107d624712a0e72d1b 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -50,7 +50,7 @@ def get_empty_tensors(gradient_shape, hessian_shape): class DenseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -183,7 +183,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) def testObliviousFeatureSplitGeneration(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 1 | 3 | @@ -320,7 +320,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(2, oblivious_split_info.children_parent_id[1]) def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -458,7 +458,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) # Batch size is 4, 2 gradients per each instance. gradients = array_ops.constant( @@ -546,7 +546,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.3, split_node.threshold, 1e-6) def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) # Batch size is 4, 2 gradients per each instance. gradients = array_ops.constant( @@ -633,7 +633,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.3, split_node.threshold, 1e-6) def testGenerateFeatureSplitCandidatesInactive(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -708,7 +708,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testGenerateFeatureSplitCandidatesWithTreeComplexity(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -842,7 +842,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) def testGenerateFeatureSplitCandidatesWithMinNodeWeight(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -951,7 +951,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Sparse Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -1074,7 +1074,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Sparse Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -1207,7 +1207,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Batch is 4, 2 classes gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) @@ -1302,7 +1302,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Batch is 4, 2 classes gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) @@ -1397,7 +1397,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesInactive(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Sparse Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -1475,7 +1475,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testEmpty(self): - with self.test_session() as sess: + with self.cached_session() as sess: indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) # No values in this feature column in this mini-batch. values = array_ops.constant([], dtype=dtypes.float32) @@ -1545,7 +1545,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testEmptyBuckets(self): """Test that reproduces the case when quantile buckets were empty.""" - with self.test_session() as sess: + with self.cached_session() as sess: sparse_column = array_ops.sparse_placeholder(dtypes.float32) # We have two batches - at first, a sparse feature is empty. @@ -1638,7 +1638,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testDegenerativeCase(self): - with self.test_session() as sess: + with self.cached_session() as sess: # One data example only, one leaf and thus one quantile bucket.The same # situation is when all examples have the same values. This case was # causing before a failure. diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index 4278a30ba9d35bc4e57364b63777c01a4508223d..46dfbdefeb00ffa075f7e7b6835b73eb258443d2 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -331,7 +331,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([[], []], dropout_info.eval()) def testObliviousEnsemble(self): - with self.test_session(): + with self.cached_session(): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Bias tree. tree1 = tree_ensemble_config.trees.add() @@ -1399,7 +1399,7 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([0, 0], result.eval()) def testObliviousTreeNonFinalized(self): - with self.test_session(): + with self.cached_session(): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Depth 3 tree. tree1 = tree_ensemble_config.trees.add() diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index b3e4c2e5f7a907892d66ad4181eb6ed8589bab6e..86fd5770a033a15df5788d3f74563c82f660371c 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -411,7 +411,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEmptyEnsembleObliviousCase(self): """Test growing an empty ensemble in the oblivious case.""" - with self.test_session() as session: + with self.cached_session() as session: # Create empty ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() tree_ensemble_handle = model_ops.tree_ensemble_variable( @@ -1620,7 +1620,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEnsembleTreeLayerByLayerObliviousCase(self): """Test growing an existing ensemble with the last tree not finalized.""" - with self.test_session() as session: + with self.cached_session() as session: # Create existing ensemble with one root split tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( @@ -1810,7 +1810,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEnsembleWithEmptyNodesMiddleCase(self): """Test case: The middle existing leaves don't have examples.""" - with self.test_session() as session: + with self.cached_session() as session: tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( """ @@ -2071,7 +2071,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEnsembleWithEmptyNodesBorderCase(self): """Test case: The first and last existing leaves don't have examples.""" - with self.test_session() as session: + with self.cached_session() as session: tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( """ diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 150d734db6cdd8023ab6d91a49872f657bcdbdea..94b7f4f867655bf7fdf94e8488eeae7088c41622 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -37,6 +37,7 @@ Checkpoint management: Saving and restoring Python state: @@NumpyState +@@PythonStateWrapper """ from __future__ import absolute_import @@ -45,6 +46,7 @@ from __future__ import print_function from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.python_state import NumpyState +from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py index 9b11035b6d277851ea0a0071062bf5cf6b6b2185..302d5cfb79a08b6adf52ebd44533152c5454eadc 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -17,7 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import functools +import six import numpy @@ -101,7 +103,7 @@ class NumpyState(base.CheckpointableBase): # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making # ndarrays checkpointable natively and using standard checkpointable list # tracking. - if isinstance(value, numpy.ndarray): + if isinstance(value, (numpy.ndarray, numpy.generic)): try: existing = super(NumpyState, self).__getattribute__(name) existing.array = value @@ -127,7 +129,29 @@ class NumpyState(base.CheckpointableBase): super(NumpyState, self).__setattr__(name, value) -class _NumpyWrapper(base.CheckpointableBase): +@six.add_metaclass(abc.ABCMeta) +class PythonStateWrapper(base.CheckpointableBase): + """Wraps a Python object for storage in an object-based checkpoint.""" + + @abc.abstractmethod + def _serialize(self): + """Callback for `PythonStringStateSaveable` to serialize the object.""" + + @abc.abstractmethod + def _deserialize(self, string_value): + """Callback for `PythonStringStateSaveable` to deserialize the object.""" + + def _gather_saveables_for_checkpoint(self): + """Specify callbacks for saving and restoring `array`.""" + return { + "py_state": functools.partial( + base.PythonStringStateSaveable, + state_callback=self._serialize, + restore_callback=self._deserialize) + } + + +class _NumpyWrapper(PythonStateWrapper): """Wraps a NumPy array for storage in an object-based checkpoint.""" def __init__(self, array): @@ -139,7 +163,7 @@ class _NumpyWrapper(base.CheckpointableBase): self.array = array def _serialize(self): - """Callback for `PythonStringStateSaveable` to serialize the array.""" + """Callback to serialize the array.""" string_file = BytesIO() try: numpy.save(string_file, self.array, allow_pickle=False) @@ -149,18 +173,10 @@ class _NumpyWrapper(base.CheckpointableBase): return serialized def _deserialize(self, string_value): - """Callback for `PythonStringStateSaveable` to deserialize the array.""" + """Callback to deserialize the array.""" string_file = BytesIO(string_value) try: self.array = numpy.load(string_file, allow_pickle=False) finally: string_file.close() - def _gather_saveables_for_checkpoint(self): - """Specify callbacks for saving and restoring `array`.""" - return { - "array": functools.partial( - base.PythonStringStateSaveable, - state_callback=self._serialize, - restore_callback=self._deserialize) - } diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py index 0439a4755e36fc3be6e065d18d3e835feda8aab3..45494351ff4e6c8c75634d8563c3fb63c6089036 100644 --- a/tensorflow/contrib/checkpoint/python/python_state_test.py +++ b/tensorflow/contrib/checkpoint/python/python_state_test.py @@ -40,10 +40,13 @@ class NumpyStateTests(test.TestCase): save_state.a = numpy.ones([2, 2]) save_state.b = numpy.ones([2, 2]) save_state.b = numpy.zeros([2, 2]) + save_state.c = numpy.int64(3) self.assertAllEqual(numpy.ones([2, 2]), save_state.a) self.assertAllEqual(numpy.zeros([2, 2]), save_state.b) + self.assertEqual(3, save_state.c) first_save_path = saver.save(prefix) save_state.a[1, 1] = 2. + save_state.c = numpy.int64(4) second_save_path = saver.save(prefix) load_state = python_state.NumpyState() @@ -51,6 +54,7 @@ class NumpyStateTests(test.TestCase): loader.restore(first_save_path).initialize_or_restore() self.assertAllEqual(numpy.ones([2, 2]), load_state.a) self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + self.assertEqual(3, load_state.c) load_state.a[0, 0] = 42. self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a) loader.restore(first_save_path).run_restore_ops() @@ -58,6 +62,7 @@ class NumpyStateTests(test.TestCase): loader.restore(second_save_path).run_restore_ops() self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a) self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + self.assertEqual(4, load_state.c) def testNoGraphPollution(self): graph = ops.Graph() diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py index 493b3c6f1b5e7a7a7dc1dd4f48d2f54c1d284098..11e177cd0c81f99bd6e00eac4de90a46fb9f64f0 100644 --- a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py +++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py @@ -197,7 +197,7 @@ class BigQueryReaderOpsTest(test.TestCase): def _ReadAndCheckRowsUsingFeatures(self, num_rows): self.server.handler.num_rows = num_rows - with self.test_session() as sess: + with self.cached_session() as sess: feature_configs = { "int64_col": parsing_ops.FixedLenFeature( @@ -254,7 +254,7 @@ class BigQueryReaderOpsTest(test.TestCase): num_rows = 10 self.server.handler.num_rows = num_rows - with self.test_session() as sess: + with self.cached_session() as sess: reader = cloud.BigQueryReader( project_id=_PROJECT, dataset_id=_DATASET, diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py index 9b6c056d6c8adfa50b95aefb8e9740631327a572..4f2ecbcb170b56ab276ec37bbaa3db2485d58f49 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py @@ -26,7 +26,7 @@ class GcsConfigOpsTest(test.TestCase): def testSetBlockCache(self): cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024) - with self.test_session() as sess: + with self.cached_session() as sess: gcs_config_ops.configure_gcs(sess, block_cache=cfg) def testConfigureGcsHook(self): @@ -36,7 +36,7 @@ class GcsConfigOpsTest(test.TestCase): 'type': 'authorized_user'} hook = gcs_config_ops.ConfigureGcsHook(credentials=creds) hook.begin() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None hook.after_create_session(sess, None) diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 0b79f718d4823a987e02804f59a432ee46d0ada3..789dab81ed848851f6597ec8dfae3d3455e84f86 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -1,6 +1,10 @@ TensorFlow CMake build ====================== +CMAKE build is deprecated for TensorFlow. Please use `bazel` to build TF for all +platforms. For details, see the +[TensorFlow install guide](https://www.tensorflow.org/install/). + This directory contains CMake files for building TensorFlow on Microsoft Windows. [CMake](https://cmake.org) is a cross-platform tool that can generate build scripts for multiple build systems, including Microsoft diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index ad2af01bc002555ce48f8b9bfb7d8d724a1a7dc8..1a147e9c8e5a9fee17a81e37c9babe3c9ec0290b 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== include (ExternalProject) +include (GNUInstallDirs) set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive) set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz) @@ -35,7 +36,7 @@ if(WIN32) endif() endif() else() - set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng16.a) + set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/${CMAKE_INSTALL_LIBDIR}/libpng16.a) endif() set(png_HEADERS diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py index 9b4bf6271009161c4c449cd9c3cdab9fba90aa59..3e25079e02eb22cb8796cce1a49a3041bed58415 100644 --- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py @@ -75,7 +75,7 @@ class ExternalRegretOptimizerTest(test.TestCase): multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1]) expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0]) - with self.test_session() as session: + with self.cached_session() as session: projected_multipliers1 = session.run( external_regret_optimizer._project_multipliers_wrt_euclidean_norm( multipliers1, 1.0)) @@ -122,7 +122,7 @@ class ExternalRegretOptimizerTest(test.TestCase): ] multipliers = [] - with self.test_session() as session: + with self.cached_session() as session: session.run(standard_ops.global_variables_initializer()) while len(multipliers) < len(expected_multipliers): multipliers.append(session.run(optimizer.lagrange_multipliers)) diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py index 34c4543dca97e12c8335e4c90b849820edaefa81..df0eced631718995fc3219657db6813da7375cba 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py @@ -97,7 +97,7 @@ class SwapRegretOptimizerTest(test.TestCase): matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]]) matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]]) - with self.test_session() as session: + with self.cached_session() as session: eigenvector1 = session.run( swap_regret_optimizer._maximal_eigenvector_power_method( standard_ops.constant(matrix1))) @@ -119,7 +119,7 @@ class SwapRegretOptimizerTest(test.TestCase): expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]]) - with self.test_session() as session: + with self.cached_session() as session: projected_matrix = session.run( swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm( matrix)) @@ -134,7 +134,7 @@ class SwapRegretOptimizerTest(test.TestCase): expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]]) - with self.test_session() as session: + with self.cached_session() as session: projected_matrix = session.run( standard_ops.exp( swap_regret_optimizer. @@ -165,7 +165,7 @@ class SwapRegretOptimizerTest(test.TestCase): ] matrices = [] - with self.test_session() as session: + with self.cached_session() as session: session.run(standard_ops.global_variables_initializer()) while len(matrices) < len(expected_matrices): matrices.append(session.run(optimizer.stochastic_matrix)) @@ -198,7 +198,7 @@ class SwapRegretOptimizerTest(test.TestCase): ] matrices = [] - with self.test_session() as session: + with self.cached_session() as session: session.run(standard_ops.global_variables_initializer()) while len(matrices) < len(expected_matrices): matrices.append(session.run(optimizer.stochastic_matrix)) diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 8cfe14205927bf7763cf36fa31012ab10fce995c..556d73184022dcc23add29114d717ab17302f8d4 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -61,7 +61,7 @@ class CrfTest(test.TestCase): for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, inputs_list, tag_indices_list): - with self.test_session() as sess: + with self.cached_session() as sess: sequence_score = crf.crf_sequence_score( inputs=array_ops.expand_dims(inputs, 0), tag_indices=array_ops.expand_dims(tag_indices, 0), @@ -96,7 +96,7 @@ class CrfTest(test.TestCase): ] for sequence_lengths, inputs, tag_bitmap in zip( sequence_lengths_list, inputs_list, tag_bitmap_list): - with self.test_session() as sess: + with self.cached_session() as sess: sequence_score = crf.crf_multitag_sequence_score( inputs=array_ops.expand_dims(inputs, 0), tag_bitmap=array_ops.expand_dims(tag_bitmap, 0), @@ -124,7 +124,7 @@ class CrfTest(test.TestCase): for dtype in (np.int32, np.int64): tag_indices = np.array([1, 2, 1, 0], dtype=dtype) sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: + with self.cached_session() as sess: unary_score = crf.crf_unary_score( tag_indices=array_ops.expand_dims(tag_indices, 0), sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), @@ -140,7 +140,7 @@ class CrfTest(test.TestCase): transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: + with self.cached_session() as sess: binary_score = crf.crf_binary_score( tag_indices=array_ops.expand_dims(tag_indices, 0), sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), @@ -176,7 +176,7 @@ class CrfTest(test.TestCase): tag_indices_list): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] # Compare the dynamic program with brute force computation. @@ -206,7 +206,7 @@ class CrfTest(test.TestCase): """ Test `crf_log_norm` when `sequence_lengths` contains one or more zeros. """ - with self.test_session() as sess: + with self.cached_session() as sess: inputs = constant_op.constant(np.ones([2, 10, 5], dtype=np.float32)) transition_params = constant_op.constant(np.ones([5, 5], @@ -226,7 +226,7 @@ class CrfTest(test.TestCase): sequence_lengths = np.array(3, dtype=np.int32) num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_log_likelihoods = [] # Make sure all probabilities sum to 1. @@ -254,7 +254,7 @@ class CrfTest(test.TestCase): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] all_sequences = [] @@ -310,7 +310,7 @@ class CrfTest(test.TestCase): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] all_sequences = [] @@ -351,7 +351,7 @@ class CrfTest(test.TestCase): """ Test that crf_decode works when sequence_length contains one or more zeros. """ - with self.test_session() as sess: + with self.cached_session() as sess: inputs = constant_op.constant(np.ones([2, 10, 5], dtype=np.float32)) transition_params = constant_op.constant(np.ones([5, 5], diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 5e6c1520a2fc1c21678625c9d4aae04164b198f6..c378b1ce8d953ec2f7d1ce7061286a8da437906d 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -26,6 +26,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@CheckpointInputPipelineHook @@CsvDataset @@LMDBDataset +@@Optional @@RandomDataset @@Reducer @@SqlDataset @@ -38,7 +39,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@copy_to_device @@dense_to_sparse_batch @@enumerate_dataset - +@@get_next_as_optional @@get_single_element @@group_by_reducer @@group_by_window @@ -46,7 +47,6 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@make_batched_features_dataset @@make_csv_dataset @@make_saveable_from_iterator - @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave @@ -62,6 +62,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@sloppy_interleave @@unbatch @@unique + +@@AUTOTUNE """ from __future__ import absolute_import @@ -91,6 +93,10 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator + +# Optimization constant that can be used to enable auto-tuning. +from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE + from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device @@ -107,10 +113,9 @@ from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch from tensorflow.contrib.data.python.ops.unique import unique from tensorflow.contrib.data.python.ops.writers import TFRecordWriter +from tensorflow.python.data.ops.iterator_ops import get_next_as_optional +from tensorflow.python.data.ops.optional_ops import Optional # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(__name__) - -# A constant that can be used to enable auto-tuning. -AUTOTUNE = -1 diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index 74107d524259ec745b860839299f278b4d257171..21ec50fb6b8a5bbff6d7fe85fb949e127ceab8ed 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -49,6 +49,9 @@ class CSVDatasetOp : public DatasetOpKernel { OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults_list)); for (int i = 0; i < record_defaults_list.size(); ++i) { + OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1, + errors::InvalidArgument( + "Each record default should be at most rank 1")); OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2, errors::InvalidArgument( "There should only be 1 default per field but field ", i, diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index ae104d55bd813fdbc9829ccbc274612a112c8e1d..ad410e17feb9de825aa3af07d4269161121a621a 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -65,7 +65,13 @@ REGISTER_OP("CSVDataset") TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); // `record_defaults` must be lists of scalars for (size_t i = 8; i < c->num_inputs(); ++i) { - TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused)); + shape_inference::ShapeHandle v; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); + if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { + return errors::InvalidArgument( + "Shape of a default must be a length-0 or length-1 vector, or a " + "scalar."); + } } return shape_inference::ScalarShape(c); }); diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index b9320e5fefce093a68c9e826e7ebe63c77294786..ba202839b2f83b61256686b955c51bc0ae2cdace 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -72,12 +72,13 @@ py_test( "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:platform_test", "//tensorflow/python:session", "//tensorflow/python/data/ops:readers", + "//tensorflow/python/eager:context", "//third_party/py/numpy", ], ) @@ -276,6 +277,7 @@ py_test( "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:data_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:function", @@ -285,21 +287,6 @@ py_test( ], ) -py_test( - name = "optimize_dataset_op_test", - size = "small", - srcs = ["optimize_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - py_test( name = "parsing_ops_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 67242fecfe3436edd01ff580279bddab9a12910f..8e368bf2bc5060e1655dd24b1d285b0ee80e094d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -57,7 +57,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for start in range(0, len(components), 4): @@ -85,7 +85,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for start in range(0, len(components), 4): @@ -123,7 +123,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize with an input tensor of incompatible rank. sess.run(init_op, feed_dict={input_tensor: [[1]]}) with self.assertRaisesRegexp(errors.InvalidArgumentError, @@ -148,7 +148,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual((i,) * 3, sess.run(op)) @@ -168,7 +168,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op)) @@ -187,7 +187,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): st_row = sess.run(next_element) self.assertEqual([i], st_row.indices) @@ -208,7 +208,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): dense_elem, st_row = sess.run(next_element) self.assertEqual(i, dense_elem) @@ -230,7 +230,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(((i,),) * 3, sess.run(op)) @@ -250,7 +250,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")), sess.run(op)) @@ -266,7 +266,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) @@ -284,7 +284,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Mismatch in the 0th dimension. sess.run( iterator.initializer, @@ -319,7 +319,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_batch_size in [1, 3, 7, 10]: sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) num_batches = 7 // test_batch_size @@ -343,7 +343,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -374,7 +374,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_batch_size in [1, 3, 7, 10]: sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) num_batches = 7 // test_batch_size @@ -461,7 +461,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: # Batch of a finite input, where the batch_size divides the # total number of elements. sess.run(init_op, feed_dict={count: 28, batch_size: 14}) @@ -520,7 +520,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): else: self.assertEqual([None, 1], iterator.output_shapes.as_list()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) if not drop_remainder: @@ -535,7 +535,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_one_shot_iterator()) self.assertEqual([None, 1], iterator.output_shapes.as_list()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) self.assertAllEqual([[64], [81]], sess.run(next_element)) @@ -549,7 +549,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): elements = [] for _ in range(100): elements.append(iterator.get_next()) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(5): got = sess.run(elements) got.sort(key=lambda x: x[0]) @@ -569,7 +569,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): elements = [] for _ in range(100): elements.append(iterator.get_next()) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(4): got = sess.run(elements) got.sort(key=lambda x: x[0]) @@ -591,7 +591,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -614,7 +614,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) .make_initializable_iterator()) init_op = iterator.initializer - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) @@ -635,7 +635,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaisesRegexp(errors.InvalidArgumentError, "number of elements does not match"): @@ -659,7 +659,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(3): sess.run(get_next) @@ -686,7 +686,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): batch_size=10)).make_one_shot_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(threshold // 10): self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next)) if threshold % 10 != 0: @@ -718,7 +718,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(10): self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) @@ -784,7 +784,7 @@ class RestructuredDatasetTest(test.TestCase): iterator = result.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(5): sess.run(get_next) @@ -908,7 +908,7 @@ class RestructuredDatasetTest(test.TestCase): .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 2022c1f2bdd09cdf43a993b3666335ce468a40ba..48971f2ccc4317d2bf591ae1e07cd6d5baf7b965 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -40,7 +40,7 @@ class GroupByReducerTest(test.TestCase): def checkResults(self, dataset, shapes, values): self.assertEqual(shapes, dataset.output_shapes) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for expected in values: got = sess.run(get_next) self.assertEqual(got, expected) @@ -129,7 +129,7 @@ class GroupByReducerTest(test.TestCase): self.assertIs(None, dataset.output_shapes[1].ndims) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: x, y = sess.run(get_next) self.assertAllEqual([0] * (2**i), x) self.assertAllEqual(np.array(1, ndmin=i), y) @@ -192,7 +192,7 @@ class GroupByReducerTest(test.TestCase): (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply( grouping.group_by_reducer(lambda x, y: np.int64(0), reducer)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: x, y = sess.run(get_next) self.assertAllEqual(x, np.asarray([x for x in range(10)])) self.assertEqual(y, 45) @@ -210,7 +210,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) counts = [] with self.assertRaises(errors.OutOfRangeError): @@ -237,7 +237,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # The input is infinite, so this test demonstrates that: # 1. We produce output without having to consume the entire input, @@ -258,7 +258,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) self.assertAllEqual([1, 1, 1, 1], sess.run(get_next)) @@ -275,7 +275,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -301,7 +301,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -329,7 +329,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) counts = [] with self.assertRaises(errors.OutOfRangeError): @@ -376,7 +376,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) which_bucket, bucketed_values = sess.run(get_next) @@ -411,7 +411,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Get two minibatches (one containing even values, one containing odds) @@ -482,7 +482,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Get two minibatches ([0, 2, ...] and [64, 66, ...]) @@ -515,7 +515,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.OutOfRangeError): batches = 0 @@ -531,6 +531,45 @@ class BucketTest(test.TestCase): self.assertEqual(batches, 15) +def _element_length_fn(x, y=None): + del y + return array_ops.shape(x)[0] + + +def _to_sparse_tensor(record): + return sparse_tensor.SparseTensor(**record) + + +def _format_record(array, sparse): + if sparse: + return { + "values": array, + "indices": [[i] for i in range(len(array))], + "dense_shape": (len(array),) + } + return array + + +def _get_record_type(sparse): + if sparse: + return { + "values": dtypes.int64, + "indices": dtypes.int64, + "dense_shape": dtypes.int64 + } + return dtypes.int32 + + +def _get_record_shape(sparse): + if sparse: + return { + "values": tensor_shape.TensorShape([None,]), + "indices": tensor_shape.TensorShape([None, 1]), + "dense_shape": tensor_shape.TensorShape([1,]) + } + return tensor_shape.TensorShape([None]) + + class BucketBySequenceLength(test.TestCase): def testBucket(self): @@ -539,39 +578,58 @@ class BucketBySequenceLength(test.TestCase): batch_sizes = [10, 8, 4, 2] lengths = [8, 13, 25, 35] - def element_gen(): - # Produce 1 batch for each bucket - elements = [] - for batch_size, length in zip(batch_sizes, lengths): - for _ in range(batch_size): - elements.append([1] * length) - random.shuffle(elements) - for el in elements: - yield (el,) - - element_len = lambda el: array_ops.shape(el)[0] - dataset = dataset_ops.Dataset.from_generator( - element_gen, (dtypes.int64,), ([None],)).apply( - grouping.bucket_by_sequence_length( - element_len, boundaries, batch_sizes)) - batch, = dataset.make_one_shot_iterator().get_next() - - with self.test_session() as sess: - batches = [] - for _ in range(4): - batches.append(sess.run(batch)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(batch) - batch_sizes_val = [] - lengths_val = [] - for batch in batches: - batch_size = batch.shape[0] - length = batch.shape[1] - batch_sizes_val.append(batch_size) - lengths_val.append(length) - self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) - self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) - self.assertEqual(sorted(lengths), sorted(lengths_val)) + def build_dataset(sparse): + def _generator(): + # Produce 1 batch for each bucket + elements = [] + for batch_size, length in zip(batch_sizes, lengths): + record_len = length - 1 + for _ in range(batch_size): + elements.append([1] * record_len) + record_len = length + random.shuffle(elements) + for el in elements: + yield (_format_record(el, sparse),) + dataset = dataset_ops.Dataset.from_generator( + _generator, + (_get_record_type(sparse),), + (_get_record_shape(sparse),)) + if sparse: + dataset = dataset.map(lambda x: (_to_sparse_tensor(x),)) + return dataset + + def _test_bucket_by_padding(no_padding): + dataset = build_dataset(sparse=no_padding) + dataset = dataset.apply( + grouping.bucket_by_sequence_length( + _element_length_fn, + boundaries, + batch_sizes, + no_padding=no_padding)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + batches = [] + for _ in range(4): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + batch_sizes_val = [] + lengths_val = [] + for batch in batches: + shape = batch.dense_shape if no_padding else batch.shape + batch_size = shape[0] + length = shape[1] + batch_sizes_val.append(batch_size) + lengths_val.append(length) + sum_check = batch.values.sum() if no_padding else batch.sum() + self.assertEqual(sum_check, batch_size * length - 1) + self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) + self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) + self.assertEqual(sorted(lengths), sorted(lengths_val)) + + for no_padding in (True, False): + _test_bucket_by_padding(no_padding) def testPadToBoundary(self): @@ -600,7 +658,7 @@ class BucketBySequenceLength(test.TestCase): pad_to_bucket_boundary=True)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(3): batches.append(sess.run(batch)) @@ -637,7 +695,7 @@ class BucketBySequenceLength(test.TestCase): pad_to_bucket_boundary=True)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(5): batches.append(sess.run(batch)) @@ -657,28 +715,108 @@ class BucketBySequenceLength(test.TestCase): def testTupleElements(self): - def elements_gen(): - text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] - label = [1, 2, 1, 2] - for x, y in zip(text, label): - yield (x, y) - - def element_length_fn(x, y): - del y - return array_ops.shape(x)[0] - - dataset = dataset_ops.Dataset.from_generator( - generator=elements_gen, - output_shapes=(tensor_shape.TensorShape([None]), - tensor_shape.TensorShape([])), - output_types=(dtypes.int32, dtypes.int32)) + def build_dataset(sparse): + def _generator(): + text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] + label = [1, 2, 1, 2] + for x, y in zip(text, label): + yield (_format_record(x, sparse), y) + dataset = dataset_ops.Dataset.from_generator( + generator=_generator, + output_types=(_get_record_type(sparse), dtypes.int32), + output_shapes=(_get_record_shape(sparse), + tensor_shape.TensorShape([]))) + if sparse: + dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y)) + return dataset + + def _test_tuple_elements_by_padding(no_padding): + dataset = build_dataset(sparse=no_padding) + dataset = dataset.apply(grouping.bucket_by_sequence_length( + element_length_func=_element_length_fn, + bucket_batch_sizes=[2, 2, 2], + bucket_boundaries=[0, 8], + no_padding=no_padding)) + shapes = dataset.output_shapes + self.assertEqual([None, None], shapes[0].as_list()) + self.assertEqual([None], shapes[1].as_list()) + + for no_padding in (True, False): + _test_tuple_elements_by_padding(no_padding) + + def testBucketSparse(self): + """Tests bucketing of sparse tensors (case where `no_padding` == True). + + Test runs on following dataset: + [ + [0], + [0, 1], + [0, 1, 2] + ... + [0, ..., max_len - 1] + ] + Sequences are bucketed by length and batched with + `batch_size` < `bucket_size`. + """ + + min_len = 0 + max_len = 100 + batch_size = 7 + bucket_size = 10 + + def _build_dataset(): + input_data = [range(i+1) for i in range(min_len, max_len)] + def generator_fn(): + for record in input_data: + yield _format_record(record, sparse=True) + dataset = dataset_ops.Dataset.from_generator( + generator=generator_fn, + output_types=_get_record_type(sparse=True)) + dataset = dataset.map(_to_sparse_tensor) + return dataset + + def _compute_expected_batches(): + """Computes expected batch outputs and stores in a set.""" + all_expected_sparse_tensors = set() + for bucket_start_len in range(min_len, max_len, bucket_size): + for batch_offset in range(0, bucket_size, batch_size): + batch_start_len = bucket_start_len + batch_offset + batch_end_len = min(batch_start_len + batch_size, + bucket_start_len + bucket_size) + expected_indices = [] + expected_values = [] + for length in range(batch_start_len, batch_end_len): + for val in range(length + 1): + expected_indices.append((length - batch_start_len, val)) + expected_values.append(val) + expected_sprs_tensor = (tuple(expected_indices), + tuple(expected_values)) + all_expected_sparse_tensors.add(expected_sprs_tensor) + return all_expected_sparse_tensors + + def _compute_batches(dataset): + """Computes actual batch outputs of dataset and stores in a set.""" + batch = dataset.make_one_shot_iterator().get_next() + all_sparse_tensors = set() + with self.cached_session() as sess: + with self.assertRaises(errors.OutOfRangeError): + while True: + output = sess.run(batch) + sprs_tensor = (tuple([tuple(idx) for idx in output.indices]), + tuple(output.values)) + all_sparse_tensors.add(sprs_tensor) + return all_sparse_tensors + + dataset = _build_dataset() + boundaries = range(min_len + bucket_size + 1, max_len, bucket_size) dataset = dataset.apply(grouping.bucket_by_sequence_length( - element_length_func=element_length_fn, - bucket_batch_sizes=[2, 2, 2], - bucket_boundaries=[0, 8])) - shapes = dataset.output_shapes - self.assertEqual([None, None], shapes[0].as_list()) - self.assertEqual([None], shapes[1].as_list()) + _element_length_fn, + boundaries, + [batch_size] * (len(boundaries) + 1), + no_padding=True)) + batches = _compute_batches(dataset) + expected_batches = _compute_expected_batches() + self.assertEqual(batches, expected_batches) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 63bffd023f0e2672f41d36e27e31c9a9b26be77c..f8e74e4583df5b4e2cdd73c94361486680cee3f4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -31,38 +31,49 @@ from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.python.client import session from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import googletest from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class CsvDatasetOpTest(test.TestCase): - def _assert_datasets_equal(self, g, ds1, ds2): + def _get_next(self, dataset): + # Returns a no argument function whose result is fed to self.evaluate to + # yield the next element + it = dataset.make_one_shot_iterator() + if context.executing_eagerly(): + return it.get_next + else: + get_next = it.get_next() + return lambda: get_next + + def _assert_datasets_equal(self, ds1, ds2): assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, ' '%s') % (ds1.output_shapes, ds2.output_shapes) assert ds1.output_types == ds2.output_types assert ds1.output_classes == ds2.output_classes - next1 = ds1.make_one_shot_iterator().get_next() - next2 = ds2.make_one_shot_iterator().get_next() - with self.session(graph=g) as sess: - # Run through datasets and check that outputs match, or errors match. - while True: - try: - op1 = sess.run(next1) - except (errors.OutOfRangeError, ValueError) as e: - # If op1 throws an exception, check that op2 throws same exception. - with self.assertRaises(type(e)): - sess.run(next2) - break - op2 = sess.run(next2) - self.assertAllEqual(op1, op2) + next1 = self._get_next(ds1) + next2 = self._get_next(ds2) + # Run through datasets and check that outputs match, or errors match. + while True: + try: + op1 = self.evaluate(next1()) + except (errors.OutOfRangeError, ValueError) as e: + # If op1 throws an exception, check that op2 throws same exception. + with self.assertRaises(type(e)): + self.evaluate(next2()) + break + op2 = self.evaluate(next2()) + self.assertAllEqual(op1, op2) def _setup_files(self, inputs, linebreak='\n', compression_type=None): filenames = [] @@ -95,33 +106,32 @@ class CsvDatasetOpTest(test.TestCase): def _test_by_comparison(self, inputs, **kwargs): """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv).""" - with ops.Graph().as_default() as g: - dataset_actual, dataset_expected = self._make_test_datasets( - inputs, **kwargs) - self._assert_datasets_equal(g, dataset_actual, dataset_expected) + dataset_actual, dataset_expected = self._make_test_datasets( + inputs, **kwargs) + self._assert_datasets_equal(dataset_actual, dataset_expected) def _verify_output_or_err(self, - sess, dataset, expected_output=None, expected_err_re=None): - nxt = dataset.make_one_shot_iterator().get_next() if expected_err_re is None: # Verify that output is expected, without errors + nxt = self._get_next(dataset) expected_output = [[ v.encode('utf-8') if isinstance(v, str) else v for v in op ] for op in expected_output] for value in expected_output: - op = sess.run(nxt) + op = self.evaluate(nxt()) self.assertAllEqual(op, value) with self.assertRaises(errors.OutOfRangeError): - sess.run(nxt) + self.evaluate(nxt()) else: # Verify that OpError is produced as expected with self.assertRaisesOpError(expected_err_re): + nxt = self._get_next(dataset) while True: try: - sess.run(nxt) + self.evaluate(nxt()) except errors.OutOfRangeError: break @@ -137,11 +147,8 @@ class CsvDatasetOpTest(test.TestCase): # Convert str type because py3 tf strings are bytestrings filenames = self._setup_files(inputs, linebreak, compression_type) kwargs['compression_type'] = compression_type - with ops.Graph().as_default() as g: - with self.session(graph=g) as sess: - dataset = readers.CsvDataset(filenames, **kwargs) - self._verify_output_or_err(sess, dataset, expected_output, - expected_err_re) + dataset = readers.CsvDataset(filenames, **kwargs) + self._verify_output_or_err(dataset, expected_output, expected_err_re) def testCsvDataset_requiredFields(self): record_defaults = [[]] * 4 @@ -191,21 +198,17 @@ class CsvDatasetOpTest(test.TestCase): record_defaults = [['']] * 3 inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] filenames = self._setup_files(inputs) - with ops.Graph().as_default() as g: - with self.session(graph=g) as sess: - dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) - dataset = dataset.apply(error_ops.ignore_errors()) - self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(dataset, [['e', 'f', 'g']]) def testCsvDataset_ignoreErrWithUnquotedQuotes(self): record_defaults = [['']] * 3 inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] filenames = self._setup_files(inputs) - with ops.Graph().as_default() as g: - with self.session(graph=g) as sess: - dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) - dataset = dataset.apply(error_ops.ignore_errors()) - self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(dataset, [['e', 'f', 'g']]) def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): record_defaults = [['']] * 3 @@ -351,10 +354,9 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['1,,3,4', '5,6,,8']] ds_actual, ds_expected = self._make_test_datasets( inputs, record_defaults=record_defaults) - with ops.Graph().as_default() as g: - self._assert_datasets_equal(g, - ds_actual.repeat(5).prefetch(1), - ds_expected.repeat(5).prefetch(1)) + self._assert_datasets_equal( + ds_actual.repeat(5).prefetch(1), + ds_expected.repeat(5).prefetch(1)) def testCsvDataset_withTypeDefaults(self): # Testing using dtypes as record_defaults for required fields @@ -373,13 +375,11 @@ class CsvDatasetOpTest(test.TestCase): ]] file_path = self._setup_files(data) - with ops.Graph().as_default() as g: - ds = readers.make_csv_dataset( - file_path, batch_size=1, shuffle=False, num_epochs=1) - next_batch = ds.make_one_shot_iterator().get_next() + ds = readers.make_csv_dataset( + file_path, batch_size=1, shuffle=False, num_epochs=1) + nxt = self._get_next(ds) - with self.session(graph=g) as sess: - result = list(sess.run(next_batch).values()) + result = list(self.evaluate(nxt()).values()) self.assertEqual(result, sorted(result)) @@ -542,6 +542,29 @@ class CsvDatasetOpTest(test.TestCase): compression_type='ZLIB', record_defaults=record_defaults) + def testCsvDataset_withScalarDefaults(self): + record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + + def testCsvDataset_with2DDefaults(self): + record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + + if context.executing_eagerly(): + err_spec = errors.InvalidArgumentError, ( + 'Each record default should be at ' + 'most rank 1.') + else: + err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2' + + with self.assertRaisesWithPredicateMatch(*err_spec): + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py index 9020a499c4a5c35202a6f776d8795186b9c86e20..eb110324d12b47fc36bc0927ad8dc94e6892dc33 100644 --- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -38,7 +38,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for _ in range(100): for i in range(10): @@ -67,7 +67,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: freqs = np.zeros([num_datasets]) for _ in range(num_samples): freqs[sess.run(next_element)] += 1 @@ -104,7 +104,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in choice_array: self.assertEqual(words[i], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index e6883d53e02c0f96d966a52abfe2f9b4118f2e12..f3968cdc15a5a34af0946c5c447ce35cdfa3e00d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -53,7 +53,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase): lambda x: (x * x, make_sparse(x))).take(take_t) element = get_single_element.get_single_element(dataset) - with self.test_session() as sess: + with self.cached_session() as sess: if error is None: dense_val, sparse_val = sess.run( element, feed_dict={ @@ -90,7 +90,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase): dataset = dataset_ops.Dataset.range(stop_t) element = get_single_element.reduce_dataset(dataset, sum_reducer) - with self.test_session() as sess: + with self.cached_session() as sess: value = sess.run(element, feed_dict={stop_t: stop}) self.assertEqual(stop * (stop - 1) / 2, value) diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py index db2ab815eeebb77c159ca8c7d0d9920f2bdcdabd..9c508d686dd44d04444a34c703ab54f3b97eeced 100644 --- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py @@ -44,14 +44,14 @@ class IndexedDatasetOpsTest(test.TestCase): get_op = gen_dataset_ops.indexed_dataset_get( handle, index, output_types=[dtypes.uint64], output_shapes=[[]]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(materialize) self.assertEqual([3], sess.run(get_op, feed_dict={index: 3})) def testIdentityIndexedDataset(self): ds = indexed_dataset_ops.IdentityIndexedDataset(16) materialized = ds.materialize() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(materialized.initializer) placeholder = array_ops.placeholder(dtypes.uint64, shape=[]) for i in range(16): @@ -66,7 +66,7 @@ class IndexedDatasetOpsTest(test.TestCase): ds = indexed_dataset_ops.IdentityIndexedDataset(16) itr = ds.make_initializable_iterator() n = itr.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(itr.initializer) for i in range(16): output = sess.run(n) diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 7a3215f6ccfa807e8930ac8561587e474da61195..b9e74dfddb1b238ab75928af17d545d2b6a3c033 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -177,7 +177,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0): # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and # `Dataset.flat_map()` and is single-threaded. No synchronization required. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -212,7 +212,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testSingleThreadedRagged(self): # Tests a sequence with wildly different elements per iterator. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -242,7 +242,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContention(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -286,7 +286,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): Args: sloppy: Whether to be sloppy or not. """ - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -328,7 +328,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContentionBlockLength(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -373,7 +373,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): Args: sloppy: Whether to be sloppy or not. """ - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -413,7 +413,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True) def _testEmptyInput(self, sloppy=False): - with self.test_session() as sess: + with self.cached_session() as sess: # Empty input. self._clear_coordination_events() sess.run( @@ -437,7 +437,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False): # Non-empty input leading to empty output. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -461,7 +461,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1): race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds # Mixture of non-empty and empty interleaved datasets. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -500,7 +500,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testDelayedOutputSloppy(self): # Explicitly control the sequence of events to ensure we correctly avoid # head-of-line blocking. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -525,7 +525,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(self.next_element) def testBlockLengthWithContentionSloppy(self): - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -560,7 +560,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testEarlyExit(self, sloppy=False): # Exiting without consuming all input should not block - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -604,7 +604,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy)) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: output_values = [] for _ in range(30): output_values.append(sess.run(iterator.get_next())) @@ -635,7 +635,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): for j in range(2): @@ -645,7 +645,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(get_next) def testErrorsInOutputFn(self): - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -704,7 +704,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={ @@ -753,7 +753,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={ @@ -792,7 +792,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): next_element = iterator.get_next() results = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(2): elements = [] sess.run(iterator.initializer) diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index 7bc582ebaa50c7418e7624a1a389f002f2cea395..1cc5ddc9a2e1eff4473c19bc397d656e7e0b90ed 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -51,7 +51,7 @@ class LMDBDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(num_repeats): # Dataset is repeated. for i in range(10): # 10 records. diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index dc9d56dd53cc077c14eda58a22d7449c05bddec1..e8519381d69427f4c9a3ef5cefa527c368251f2a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -54,7 +54,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for x in [1., 2., 3., 5.]: self.assertEqual(x, sess.run(get_next)) @@ -72,7 +72,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for x in [1., 2., 3., 5.]: self.assertEqual(x, sess.run(get_next)) @@ -99,7 +99,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # All of the files are present. sess.run(init_op) for filename in filenames: @@ -209,7 +209,7 @@ class MapDatasetBenchmark(test.Benchmark): end = time.time() chained_deltas.append(end - start) - fused_dataset = dataset = dataset.apply( + fused_dataset = dataset.apply( batching.map_and_batch( math_ops.matmul, num_parallel_calls=num_calls, diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py index 61567bc8d76d13240e5ada4f5e25d8c98d883136..83b723710ca1d37a8d2b1e297321b59dcaa17ba6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -207,6 +208,31 @@ class MapDefunTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(r, feed_dict={p: 0}) + def _assert_op_cancelled(self, sess, map_defun_op): + with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"): + sess.run(map_defun_op) + + def testMapDefunWithParentCancellation(self): + # Checks that a cancellation of the parent graph is threaded through to + # MapDefunOp correctly. + @function.Defun(dtypes.int32) + def simple_fn(x): + del x + queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ()) + # Blocking + return queue.dequeue_many(5) + + c = constant_op.constant([1, 2, 3, 4, 5]) + map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0] + + with self.test_session() as sess: + thread = self.checkedThread( + self._assert_op_cancelled, args=(sess, map_defun_op)) + thread.start() + time.sleep(0.1) + sess.close() + thread.join() + class MapDefunBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD index b299e0736fb29d0936680e5905172b0fa95ac586..7e9ea68047a076d368cf98960f4754b29abb074e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD @@ -6,6 +6,34 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") +py_test( + name = "assert_next_dataset_op_test", + size = "medium", + srcs = ["assert_next_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "latency_all_edges_test", + size = "small", + srcs = ["latency_all_edges_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "map_vectorization_test", size = "small", @@ -46,16 +74,34 @@ py_test( ) py_test( - name = "latency_all_edges_test", + name = "model_dataset_op_test", + size = "medium", + srcs = ["model_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "optonly", + ], + deps = [ + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "optimize_dataset_op_test", size = "small", - srcs = ["latency_all_edges_test.py"], + srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/contrib/data/python/ops:stats_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7b50b902cf5965bfdb586c5c9fce68ba5d9cd6 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class AssertNextDatasetTest(test.TestCase): + + def testAssertNext(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(0, sess.run(get_next)) + + def testAssertNextInvalid(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted Whoops transformation at offset 0 but encountered " + "Map transformation instead."): + sess.run(get_next) + + def testAssertNextShort(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted next 2 transformations but encountered only 1."): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py index 1850b6921af0aae8d26fbdfd165fd0e087134e6d..db380c02a9191bec53d5e32565d47a52cbdd44b1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py @@ -40,7 +40,7 @@ class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): get_next = iterator.get_next() summary_t = stats_aggregator.get_summary() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertEqual(1 * 1, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py index 6a7ef877f9a75af565239a4f498da3558863fc35..dde115925ee484edb88ad81b21595c3d668be84c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -74,7 +74,7 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"])) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for x in range(5): result = sess.run(get_next) r = x @@ -131,7 +131,7 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): def _testMapAndFilter(self, dataset, function, predicate): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for x in range(10): r = function(x) if isinstance(r, tuple): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0a87d3e90550da8485b4f9acd941c836d7b62951 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py @@ -0,0 +1,177 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ModelDatasetTest(test.TestCase): + + def testModelMap(self): + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.map(math_ops.matmul) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.test_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(100): + start = time.time() + sess.run(get_next.op) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelParallelMap(self): + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.map(math_ops.matmul, num_parallel_calls=56) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.test_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(1000): + start = time.time() + sess.run(get_next.op) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelMapAndBatch(self): + batch_size = 16 + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.apply( + batching.map_and_batch( + math_ops.matmul, num_parallel_calls=28, batch_size=batch_size)) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.test_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(10): + start = time.time() + sess.run(get_next.op) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelParallelInterleave(self): + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.map(math_ops.matmul) + dataset = dataset_ops.Dataset.range(1).repeat().interleave( + lambda _: dataset, cycle_length=56, num_parallel_calls=56) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.test_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(1000): + start = time.time() + sess.run(get_next.op) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelNested(self): + k = 1024 * 1024 + a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1)) + b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1)) + c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1)) + dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat() + + def f1(a, b, c): + x, y = a + return math_ops.matmul(x, y), b, c + + def f2(a, b, c): + x, y = b + return a, math_ops.matmul(x, y), c + + def f3(a, b, c): + x, y = c + return a, b, math_ops.matmul(x, y) + + dataset = dataset.map(f1, num_parallel_calls=32) + dataset = dataset_ops.Dataset.range(1).repeat().interleave( + lambda _: dataset, cycle_length=2) + + dataset = dataset.map(f2, num_parallel_calls=16) + dataset = dataset_ops.Dataset.range(1).repeat().interleave( + lambda _: dataset, cycle_length=2) + + dataset = dataset.map(f3, num_parallel_calls=10) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.test_session() as sess: + for _ in range(5): + sess.run(get_next) + for _ in range(100): + start = time.time() + sess.run(get_next) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py similarity index 75% rename from tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py index 089717156c545a0ea9262c4380ab2c0fd088e209..909da5aee0ad8bce0b5b18facbcbb684dd334abf 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import optimization @@ -29,41 +28,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): - - def testAssertSuffix(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - optimization.assert_next(["Map"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - self.assertEqual(0, sess.run(get_next)) - - def testAssertSuffixInvalid(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - optimization.assert_next(["Whoops"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Asserted Whoops transformation at offset 0 but encountered " - "Map transformation instead."): - sess.run(get_next) - - def testAssertSuffixShort(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Asserted next 2 transformations but encountered only 1."): - sess.run(get_next) +class OptimizeDatasetTest(test.TestCase): def testOptimizationDefault(self): dataset = dataset_ops.Dataset.range(10).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py index f6c4a984b8608b408bc1b1bb4a712ef1c3792696..c4623bca73228b76802ed40b18eb49662f6f7d34 100644 --- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py @@ -80,7 +80,7 @@ class ParseExampleTest(test.TestCase): expected_values=None, expected_err=None): - with self.test_session() as sess: + with self.cached_session() as sess: if expected_err: with self.assertRaisesWithPredicateMatch(expected_err[0], expected_err[1]): diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 361fe0dd39bb3f855c3b0b11281a9909fd601232..0166ba0d44ef473ac54ee4f67078c1a51fddacf3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -235,7 +235,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): destroy_op = resource_variable_ops.destroy_resource_op( buffer_resource_handle, ignore_lookup_error=True) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual([b"a"], sess.run(prefetch_op)) self.assertEqual([b"b"], sess.run(prefetch_op)) self.assertEqual([b"c"], sess.run(prefetch_op)) @@ -301,7 +301,7 @@ class PrefetchToDeviceTest(test.TestCase): self.assertEqual(dtypes.int64, next_element.dtype) self.assertEqual([], next_element.shape) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(i, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -384,7 +384,7 @@ class PrefetchToDeviceTest(test.TestCase): iterator = device_dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(i, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -435,7 +435,7 @@ class PrefetchToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -683,7 +683,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -702,7 +702,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -721,7 +721,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -739,7 +739,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -757,7 +757,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -775,7 +775,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -796,7 +796,7 @@ class CopyToDeviceTest(test.TestCase): iterator = back_to_cpu_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -875,7 +875,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -897,7 +897,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -920,7 +920,7 @@ class CopyToDeviceTest(test.TestCase): elem_has_value_t = next_elem.has_value() elem_value_t = next_elem.get_value() - with self.test_session() as sess: + with self.cached_session() as sess: # Before initializing the iterator, evaluating the optional fails with # a FailedPreconditionError. with self.assertRaises(errors.FailedPreconditionError): diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 592642da0cfd84e50cb20d9b2e534411faf927e8..db8fe6aa1b29c5c3f872e580491d978f03360fe4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -43,7 +43,7 @@ class RangeDatasetTest(test.TestCase): self.assertEqual([tensor_shape.TensorShape([])] * 3, [t.shape for t in get_next[1]]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next)) self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next)) @@ -63,7 +63,7 @@ class RangeDatasetTest(test.TestCase): .make_one_shot_iterator()) negative_get_next = negative_iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(3, sess.run(get_next)) self.assertEqual(3 + 4, sess.run(get_next)) self.assertEqual(3 + 2 * 4, sess.run(get_next)) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index fd00cdc5c61cb0a6bbee87963ed4097a236507d3..ed75b27a4493f9ebb9db34c4e656d394236ae08e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -116,7 +116,7 @@ class ReadBatchFeaturesTest( init_op = iterator.initializer next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for file_batch, _, _, _, record_batch, _ in self._next_expected_batch( range(self._num_files), 2, 10): diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index c5cfddb72b56a1bcffc80c0dd34994def3ee45cd..16b1441baab925ed5b6eee4193203690d1552f03 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -77,7 +77,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): class_func=lambda c, _: c, seed=27)).make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] while len(returned) < 4000: returned.append(sess.run(get_next)) @@ -115,7 +115,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] with self.assertRaises(errors.OutOfRangeError): while True: @@ -146,7 +146,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] with self.assertRaises(errors.OutOfRangeError): while True: diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index 42cada0b97bcd9ab755297e8b1f0667766f7999e..dde678bd544fc2eaba36f91491fc64e4c7910756 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -50,7 +50,7 @@ class ScanDatasetTest(test.TestCase): start, make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), (10, 2, 10), (10, -1, 10), @@ -100,7 +100,7 @@ class ScanDatasetTest(test.TestCase): make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), (10, 2, 10), (10, -1, 10), @@ -133,7 +133,7 @@ class ScanDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(5): (longer_vector_val, larger_rank_val), _ = sess.run(next_element) self.assertAllEqual([0] * (2**i), longer_vector_val) diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 077abd6b30eafe857d27d84e533b15e4e98134e6..440e48db3095fe7006d510f7db80ad5327284659 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -35,7 +35,7 @@ class ShuffleAndRepeatTest(test.TestCase): def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True): get_next = ds_fn().make_one_shot_iterator().get_next() outputs = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(num_outputs): outputs.append(sess.run(get_next)) if verify_exhausted: diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 6b3e8e9f6e950cd8d128d22f84f139aee71aa746..90d18dca2aa727ea51d636cb971f48b50bc0c663 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -75,7 +75,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -139,7 +139,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -180,7 +180,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): window_stride=window_stride_t)).make_initializable_iterator()) init_op = iterator.initializer - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run( init_op, @@ -214,7 +214,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) num_batches = (10 - 5) // 3 + 1 for i in range(num_batches): @@ -243,7 +243,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) num_batches = (10 - 5) // 3 + 1 for i in range(num_batches): @@ -277,7 +277,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Slide: 1st batch. actual = sess.run(get_next) @@ -316,7 +316,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): .make_initializable_iterator()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) with self.assertRaisesRegexp( errors.InvalidArgumentError, diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index 2c2cfbebff5d3eba00f120467102b4185d81ab24..52823d3fcace841ff0a68b8036c4e357f7c3c7b4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -30,7 +30,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string), 2) - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(2): # Run twice to verify statelessness of db operations. sess.run( init_op, @@ -48,7 +48,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetJoinQuery(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -67,7 +67,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetNullTerminator(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -86,7 +86,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetReuseSqlDataset(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -114,7 +114,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadEmptyResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -128,7 +128,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithInvalidDriverName(self): init_op = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string))[0] - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run( init_op, @@ -142,7 +142,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithInvalidColumnName(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -157,7 +157,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetOfQueryWithSyntaxError(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -173,7 +173,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -190,7 +190,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetOfInsertQuery(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -205,7 +205,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int8` tensor. def testReadResultSetInt8(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -222,7 +222,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetInt8NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -238,7 +238,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int8` tensor. def testReadResultSetInt8MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -256,7 +256,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int16` tensor. def testReadResultSetInt16(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -273,7 +273,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetInt16NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -289,7 +289,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int16` tensor. def testReadResultSetInt16MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -307,7 +307,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int32` tensor. def testReadResultSetInt32(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -321,7 +321,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place it in an `int32` tensor. def testReadResultSetInt32NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -337,7 +337,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int32` tensor. def testReadResultSetInt32MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -355,7 +355,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # table and place it in an `int32` tensor. def testReadResultSetInt32VarCharColumnAsInt(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -371,7 +371,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # and place it in an `int64` tensor. def testReadResultSetInt64(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -387,7 +387,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place it in an `int64` tensor. def testReadResultSetInt64NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -403,7 +403,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int64` tensor. def testReadResultSetInt64MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -422,7 +422,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in a `uint8` tensor. def testReadResultSetUInt8(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -438,7 +438,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place them in `uint8` tensors. def testReadResultSetUInt8MinAndMaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -456,7 +456,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # and place it in a `uint16` tensor. def testReadResultSetUInt16(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -472,7 +472,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place them in `uint16` tensors. def testReadResultSetUInt16MinAndMaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -491,7 +491,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # in `bool` tensors. def testReadResultSetBool(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -508,7 +508,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # from a SQLite database table and place it as `True` in a `bool` tensor. def testReadResultSetBoolNotZeroOrOne(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -525,7 +525,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -544,7 +544,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64OverlyPrecise(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -570,7 +570,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 43067b4245d879aef9a40dc546b2a7742b3dc09c..e25570c5ad1e913c67c3c4339b3bdaf0523ccb04 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -75,6 +75,31 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): sess.run(next_element) self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + def testPrefetchBufferUtilization(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch( + -1).apply(stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertAllEqual( + np.array([i] * i, dtype=np.int64), sess.run(next_element)) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", + float(i + 1)) + self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization", + 0, 1) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", + 100) + def testReinitialize(self): stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py index 9a13acf8f0ac6690cad8847873768562da795496..2f5a44408fab5a686e5621660e7e3aca3e36954a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -34,6 +34,16 @@ class StatsDatasetTestBase(test.TestCase): return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertLessEqual(min_value, value.histo.min) + self.assertGreaterEqual(max_value, value.histo.max) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + def _assertSummaryHasSum(self, summary_str, tag, expected_value): summary_proto = summary_pb2.Summary() summary_proto.ParseFromString(summary_str) diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py index 1d70b16041e902a5d08383887cbf647eac2e816c..4c3353fe4046d6b2bfabac580b46f88c8d7f2941 100644 --- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py +++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py @@ -31,7 +31,7 @@ class DatasetTestBase(test.TestCase): # TODO(rachelim): support sparse tensor outputs next1 = dataset1.make_one_shot_iterator().get_next() next2 = dataset2.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: while True: try: op1 = sess.run(next1) @@ -52,9 +52,12 @@ class DatasetTestBase(test.TestCase): dataset2, exception_class, replacements=None): - next1 = dataset1.make_one_shot_iterator().get_next() - next2 = dataset2.make_one_shot_iterator().get_next() - with self.test_session() as sess: + # We are defining next1 and next2 in the same line so that we get identical + # file:line_number in the error messages + # pylint: disable=line-too-long + next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next() + # pylint: enable=line-too-long + with self.cached_session() as sess: try: sess.run(next1) raise ValueError( diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 4b08ec759d8ad1363300de742fc92d9b9a41f363..8d335e87d549426f275768e874e4cbf466ed5dcc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -69,7 +69,7 @@ class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) thread_ids = [] try: diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py index d79a842e7a5d816e2e6a52fc83acbd6b260cf64b..f994c8563f6173a7d8943aaedc854a53e16dad24 100644 --- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -45,7 +45,7 @@ class UniqueDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_case, expected in test_cases: current_test_case = test_case sess.run(iterator.initializer) diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py index ff4d9b3260046b0c1747daba37ede7ffcddeba0c..6eaa0b195911acb057b30b8ca7408cdbfdce8352 100644 --- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py @@ -92,7 +92,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): dataset = self._structuredDataset(structure, shape, dtype).apply( grouping.window_dataset(5)).flat_map(fn) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run(self._structuredElement(structure, shape, dtype)) actual = sess.run(get_next) self._assertEqual(expected, actual) @@ -128,7 +128,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply( grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredElement(structure, np.concatenate( ([5], shape), axis=0), dtype)) @@ -155,7 +155,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shape_t: shape}) expected = sess.run( self._structuredElement(None, np.concatenate(([5], shape), axis=0), @@ -235,7 +235,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): structure, shape, dtype).repeat(5).apply( grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredSparseElement(structure, np.concatenate(([5], shape), axis=0), @@ -263,7 +263,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shape_t: shape}) expected = sess.run( self._structuredSparseElement(None, @@ -321,7 +321,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping.window_dataset(len(shapes))).apply( grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) expected = sess.run( self._structuredElement( @@ -352,7 +352,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shapes_t: shapes}) expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) expected = sess.run( @@ -380,7 +380,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping._map_x_dataset( lambda x: batching.padded_batch_window(x, padded_shape))) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -458,7 +458,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): structure, shapes, dtype).apply(grouping.window_dataset( len(shapes))).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredRaggedSparseElement(structure, shapes, dtype, padded_shape)) @@ -489,7 +489,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shapes_t: shapes}) expected = sess.run( self._structuredRaggedSparseElement(None, shapes, dtypes.int32, @@ -516,7 +516,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping._map_x_dataset( lambda x: batching.padded_batch_window(x, padded_shape))) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py index c603ecc5ab27a711557376246b093fd5f80f8aec..867ee2ba3794df77df64b3346138cdffb526abdc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py @@ -61,7 +61,7 @@ class TFRecordWriterTest(test.TestCase): return os.path.join(self.get_temp_dir(), "tf_record.out.txt") def testWrite(self): - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ self.filename: self._createFile(), @@ -71,7 +71,7 @@ class TFRecordWriterTest(test.TestCase): def testWriteZLIB(self): options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ @@ -84,7 +84,7 @@ class TFRecordWriterTest(test.TestCase): def testWriteGZIP(self): options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 4b45cc7e36d14e99d1132b919dfc175a1217f8b9..a14781cd933e12ff1f04ad7ea26a923e2e1ef9e4 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -80,6 +80,7 @@ py_library( ":batching", ":gen_dataset_ops", ":interleave_ops", + ":optimization", ":parsing_ops", ":shuffle_ops", "//tensorflow/python:constant_op", diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 6edc1d79902c571b34b6a0a108c4d62cb6097ccb..099e10db921b78fc9fa3bcf73979ae6c33bc1972 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -124,7 +124,8 @@ def bucket_by_sequence_length(element_length_func, bucket_batch_sizes, padded_shapes=None, padding_values=None, - pad_to_bucket_boundary=False): + pad_to_bucket_boundary=False, + no_padding=False): """A transformation that buckets elements in a `Dataset` by length. Elements of the `Dataset` are grouped together by length and then are padded @@ -152,6 +153,8 @@ def bucket_by_sequence_length(element_length_func, unknown size to bucket boundary minus 1 (i.e., the maximum length in each bucket), and caller must ensure that the source `Dataset` does not contain any elements with length longer than `max(bucket_boundaries)`. + no_padding: `bool`, indicates whether to pad the batch features (features + need to be either of type `tf.SparseTensor` or of same shape). Returns: A `Dataset` transformation function, which can be passed to @@ -199,7 +202,9 @@ def bucket_by_sequence_length(element_length_func, def batching_fn(bucket_id, grouped_dataset): """Batch elements in dataset.""" - batch_size = batch_sizes[bucket_id] + batch_size = window_size_fn(bucket_id) + if no_padding: + return grouped_dataset.batch(batch_size) none_filler = None if pad_to_bucket_boundary: err_msg = ("When pad_to_bucket_boundary=True, elements must have " diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index fa1b851ad74bcf2cff69d42bce3eaa38822cd663..73840452dfd4578c0c37a60d3b3dc345ace996c6 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -24,6 +24,9 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +# A constant that can be used to enable auto-tuning. +AUTOTUNE = -1 + # TODO(jsimsa): Support RE matching for both individual transformation (e.g. to # account for indexing) and transformation sequence. @@ -46,6 +49,21 @@ def assert_next(transformations): return _apply_fn +def model(): + """A transformation that models performance. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _ModelDataset(dataset) + + return _apply_fn + + def optimize(optimizations=None): """A transformation that applies optimizations. @@ -97,6 +115,32 @@ class _AssertNextDataset(dataset_ops.Dataset): return self._input_dataset.output_types +class _ModelDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and models performance.""" + + def __init__(self, input_dataset): + """See `optimize()` for details.""" + super(_ModelDataset, self).__init__() + self._input_dataset = input_dataset + + def _as_variant_tensor(self): + return gen_dataset_ops.model_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + class _OptimizeDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and applies optimizations.""" diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 4c466781f7f659e8d7e267500a118d482d76da15..785b39570706a60ab1ce4462d5b6d33adff6c964 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.data.python.ops import optimization from tensorflow.contrib.data.python.ops import parsing_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops @@ -214,18 +215,17 @@ def _maybe_shuffle_and_repeat( return dataset -def make_tf_record_dataset( - file_pattern, - batch_size, - parser_fn=None, - num_epochs=None, - shuffle=True, - shuffle_buffer_size=None, - shuffle_seed=None, - prefetch_buffer_size=None, - num_parallel_reads=None, - num_parallel_parser_calls=None, - drop_final_batch=False): +def make_tf_record_dataset(file_pattern, + batch_size, + parser_fn=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=None, + shuffle_seed=None, + prefetch_buffer_size=optimization.AUTOTUNE, + num_parallel_reads=None, + num_parallel_parser_calls=None, + drop_final_batch=False): """Reads and optionally parses TFRecord files into a dataset. Provides common functionality such as batching, optional parsing, shuffling, @@ -300,8 +300,6 @@ def make_tf_record_dataset( parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls, drop_remainder=drop_final_batch)) - if prefetch_buffer_size is None: - prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE if prefetch_buffer_size == 0: return dataset else: @@ -323,7 +321,7 @@ def make_csv_dataset( shuffle=True, shuffle_buffer_size=10000, shuffle_seed=None, - prefetch_buffer_size=1, + prefetch_buffer_size=optimization.AUTOTUNE, num_parallel_reads=1, sloppy=False, num_rows_for_inference=100, @@ -386,9 +384,10 @@ def make_csv_dataset( shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size ensures better shuffling, but increases memory usage and startup time. shuffle_seed: Randomization seed to use for shuffling. - prefetch_buffer_size: An int specifying the number of feature batches to - prefetch for performance improvement. Recommended value is the number of - batches consumed per training step. + prefetch_buffer_size: An int specifying the number of feature + batches to prefetch for performance improvement. Recommended value is the + number of batches consumed per training step. Defaults to auto-tune. + num_parallel_reads: Number of threads used to read CSV records from files. If >1, the results will be interleaved. sloppy: If `True`, reading performance will be improved at @@ -666,7 +665,7 @@ def make_batched_features_dataset(file_pattern, shuffle=True, shuffle_buffer_size=10000, shuffle_seed=None, - prefetch_buffer_size=1, + prefetch_buffer_size=optimization.AUTOTUNE, reader_num_threads=1, parser_num_threads=2, sloppy_ordering=False, @@ -739,7 +738,7 @@ def make_batched_features_dataset(file_pattern, shuffle_seed: Randomization seed to use for shuffling. prefetch_buffer_size: Number of feature batches to prefetch in order to improve performance. Recommended value is the number of batches consumed - per training step (default is 1). + per training step. Defaults to auto-tune. reader_num_threads: Number of threads used to read `Example` records. If >1, the results will be interleaved. parser_num_threads: Number of threads to use for parsing `Example` tensors diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 30e1992c015d35859218d1b7fe3b2f3eb7c09b9b..91a27f97b7f75511db4b377220a353787beca30e 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -76,7 +76,7 @@ We then compile the Keras model and pass the `MirroredStrategy` object in the ```python model.compile(loss='mean_squared_error', optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2), - distribute=strategy) + distribute=distribution) ``` To train the model we call Keras `fit` API using the input dataset that we diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index c524d8b394afa664acf88f3e54eb125b061b2217..aaecbb0eb1fc889b4219991935dbc4f7410f27e4 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -485,7 +485,6 @@ py_library( srcs = ["single_loss_example.py"], deps = [ ":step_fn", - "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:layers", @@ -708,19 +707,32 @@ cuda_py_test( ], ) -cuda_py_test( - name = "keras_test", +py_library( + name = "keras_test_lib", + testonly = 1, srcs = ["keras_test.py"], - additional_deps = [ - "//third_party/py/numpy", + deps = [ + ":combinations", "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:client_testlib", "//tensorflow/python:training", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "keras_test", + srcs = ["keras_test.py"], + additional_deps = [ + ":keras_test_lib", ], tags = [ "multi_and_single_gpu", + "no_pip", "no_windows_gpu", "notsan", ], diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 2301ba9233d29a1e5d054e71e4d9383af8bd48fd..244d1fcec8ba481337afeede181c29d0552e3c44 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -50,10 +50,12 @@ from tensorflow.contrib.cluster_resolver import TPUClusterResolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib +from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2 from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.training import adagrad from tensorflow.python.training import adam from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent @@ -328,6 +330,10 @@ tpu_strategy = NamedDistribution( "TPU", lambda: tpu_lib.TPUStrategy( TPUClusterResolver(""), steps_per_run=5), required_tpu=True) +tpu_strategy_one_step = NamedDistribution( + "TPU", lambda: tpu_lib.TPUStrategy( + TPUClusterResolver(""), steps_per_run=1), + required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( @@ -343,17 +349,23 @@ mirrored_strategy_with_two_gpus = NamedDistribution( adam_optimizer_v1_fn = NamedObject( - "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) + "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1)) gradient_descent_optimizer_v1_fn = NamedObject( "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) -optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn] +adagrad_optimizer_v1_fn = NamedObject( + "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001)) +optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn, + adagrad_optimizer_v1_fn] adam_optimizer_v2_fn = NamedObject( - "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1)) + "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) gradient_descent_optimizer_v2_fn = NamedObject( "GradientDescentV2", lambda: gradient_descent_v2.GradientDescentOptimizer(0.2)) -optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn] +adagrad_optimizer_v2_fn = NamedObject( + "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001)) +optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn, + adagrad_optimizer_v2_fn] graph_and_eager_modes = ["graph", "eager"] diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index 049513463667e15d3b6e0cc57b84c6f828d1c215..a84ef041960e389c08246fc8a16df2300856d968 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -63,7 +63,6 @@ def get_input_datasets(): # eval dataset eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) eval_ds = eval_ds.repeat() - eval_ds = eval_ds.shuffle(100) eval_ds = eval_ds.batch(64, drop_remainder=True) return train_ds, eval_ds, input_shape diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 3cee3e37a702b94f0972fa56b0639b1c01e73667..5f35e381899a03f12cf0a6ed0168b9e500d41801 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -18,9 +18,12 @@ from __future__ import division from __future__ import print_function import os +from absl.testing import parameterized import numpy as np +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.contrib.distribute.python import values from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops @@ -31,6 +34,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -63,6 +67,32 @@ def simple_functional_model(): return model +def multi_inputs_multi_outputs_model(): + input_a = keras.layers.Input(shape=(16,), name='input_a') + input_b = keras.layers.Input(shape=(16,), name='input_b') + input_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m') + dense = keras.layers.Dense(8, name='dense_1') + + interm_a = dense(input_a) + # Read m + interm_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m) + interm_s = keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a]) + interm_b = dense(input_b) + merged = keras.layers.concatenate([interm_s, interm_b], name='merge') + output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged) + output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged) + model = keras.models.Model( + inputs=[input_a, input_b, input_m], outputs=[output_c, output_d]) + model.compile( + loss='categorical_crossentropy', + optimizer=gradient_descent.GradientDescentOptimizer(0.001), + metrics={ + 'dense_2': 'categorical_accuracy', + 'dense_3': 'categorical_accuracy' + }) + return model + + def get_ds_train_input_fn(): np.random.seed(_RANDOM_SEED) (x_train, y_train), _ = testing_utils.get_test_data( @@ -91,6 +121,68 @@ def get_ds_test_input_fn(): return dataset +def get_multi_inputs_multi_outputs_data(): + (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(16,), + num_classes=3, + random_seed=_RANDOM_SEED) + (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(16,), + num_classes=2, + random_seed=_RANDOM_SEED) + (m_train, _), (m_test, _) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(8,), + num_classes=2, + random_seed=_RANDOM_SEED) + + c_train = keras.utils.to_categorical(c_train) + c_test = keras.utils.to_categorical(c_test) + d_train = keras.utils.to_categorical(d_train) + d_test = keras.utils.to_categorical(d_test) + + train_data = { + 'input_a': a_train, + 'input_b': b_train, + 'input_m': m_train, + 'output_c': c_train, + 'output_d': d_train + } + test_data = { + 'input_a': a_test, + 'input_b': b_test, + 'input_m': m_test, + 'output_c': c_test, + 'output_d': d_test + } + + return (train_data, test_data) + + +def batch_wrapper(dataset, batch_size, distribution): + # TPUs currently require fully defined input shapes, drop_remainder ensures + # the input will have fully defined shapes. + if isinstance(distribution, tpu_strategy.TPUStrategy): + return dataset.batch(batch_size, drop_remainder=True) + else: + return dataset.batch(batch_size) + + +def all_combinations(): + return combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.tpu_strategy_one_step], + mode=['graph']) + + class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): def setUp(self): @@ -99,6 +191,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): gfile.MakeDirs(self._base_dir) self._config = run_config_lib.RunConfig( tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) + self._dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) def tearDown(self): writer_cache.FileWriterCache.clear() @@ -152,6 +246,53 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) + def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self): + train_data, test_data = get_multi_inputs_multi_outputs_data() + + def train_input_fn(): + input_dict = { + 'input_a': train_data['input_a'], + 'input_b': train_data['input_b'], + 'input_m': train_data['input_m'].astype(np.str) + } + output_dict = { + 'dense_2': train_data['output_c'], + 'dense_3': train_data['output_d'] + } + return dataset_ops.Dataset.from_tensor_slices((input_dict, + output_dict)).batch(16) + + def eval_input_fn(): + input_dict = { + 'input_a': test_data['input_a'], + 'input_b': test_data['input_b'], + 'input_m': test_data['input_m'].astype(np.str) + } + output_dict = { + 'dense_2': test_data['output_c'], + 'dense_3': test_data['output_d'] + } + return dataset_ops.Dataset.from_tensor_slices((input_dict, + output_dict)).batch(16) + + self.do_test_multi_inputs_multi_outputs_with_input_fn( + train_input_fn, eval_input_fn) + + def do_test_multi_inputs_multi_outputs_with_input_fn(self, train_input_fn, + eval_input_fn): + config = run_config_lib.RunConfig( + tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=self._dist) + with self.cached_session(): + model = multi_inputs_multi_outputs_model() + est_keras = keras_lib.model_to_estimator(keras_model=model, config=config) + baseline_eval_results = est_keras.evaluate( + input_fn=eval_input_fn, steps=1) + est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) + eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) + self.assertLess(eval_results['loss'], baseline_eval_results['loss']) + def test_keras_optimizer_with_distribution_strategy(self): dist = mirrored_strategy.MirroredStrategy( devices=['/device:GPU:0', '/device:GPU:1']) @@ -175,7 +316,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): gfile.DeleteRecursively(self._config.model_dir) -class TestWithDistributionStrategy(test.TestCase): +class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): def test_validating_dataset_input_tensors_with_shape_mismatch(self): with self.cached_session(): @@ -215,7 +356,7 @@ class TestWithDistributionStrategy(test.TestCase): distributed_training_utils.validate_distributed_dataset_inputs( strategy, x, y) - def test_calling_model_on_same_dataset(self): + def test_calling_model_with_numpy_arrays(self): with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) @@ -228,11 +369,44 @@ class TestWithDistributionStrategy(test.TestCase): '/device:GPU:0']) model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, + validation_data=(inputs, targets)) + + # TODO(anjalisridhar): We need tests for when the batch size and steps are + # smaller and results in a 0 batch_size and steps value. + model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) + + @combinations.generate(all_combinations()) + def test_calling_model_on_same_dataset(self, distribution): + with self.cached_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + inputs = np.zeros((10, 3), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - dataset = dataset.batch(10) + dataset = batch_wrapper(dataset, 10, distribution) # Call fit with validation data model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, @@ -241,6 +415,9 @@ class TestWithDistributionStrategy(test.TestCase): validation_data=dataset, validation_steps=2) model.predict(dataset, steps=2) + # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work + # as clone_model's input_tensors argument only seems to accept list and not + # tuples or dict. def test_fit_with_tuple_and_dict_dataset_inputs(self): with self.cached_session(): a = keras.layers.Input(shape=(3,), name='input_a') @@ -282,7 +459,8 @@ class TestWithDistributionStrategy(test.TestCase): model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) - def test_fit_eval_and_predict_methods_on_dataset(self): + @combinations.generate(all_combinations()) + def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) @@ -291,16 +469,13 @@ class TestWithDistributionStrategy(test.TestCase): optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) inputs = np.zeros((10, 3), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - dataset = dataset.batch(10) + dataset = batch_wrapper(dataset, 10, distribution) model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) model.evaluate(dataset, steps=2, verbose=1) @@ -496,6 +671,8 @@ class TestWithDistributionStrategy(test.TestCase): class LossMaskingWithDistributionStrategyTest(test.TestCase): + # TODO(priyag): Enable all strategies for this test. Currently it does not + # work for TPU due to some invalid datatype. def test_masking(self): with self.cached_session(): np.random.seed(1337) @@ -519,24 +696,25 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase): self.assertEqual(hist.history['loss'][0], 0) -class NormalizationLayerWithDistributionStrategyTest(test.TestCase): +class NormalizationLayerWithDistributionStrategyTest( + test.TestCase, parameterized.TestCase): - def test_batchnorm_correctness(self): + @combinations.generate(all_combinations()) + def test_batchnorm_correctness(self, distribution): with self.cached_session(): model = keras.models.Sequential() norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) model.add(norm) - strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0', - '/device:GPU:0']) model.compile(loss='mse', optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=strategy) + distribute=distribution) # centered on 5.0, variance 10.0 x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) + x = x.astype('float32') dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) dataset = dataset.repeat(100) - dataset = dataset.batch(32) + dataset = batch_wrapper(dataset, 32, distribution) model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) out = model.predict(dataset, steps=2) @@ -546,9 +724,11 @@ class NormalizationLayerWithDistributionStrategyTest(test.TestCase): np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) -class CorrectnessWithDistributionStrategyTest(test.TestCase): +class CorrectnessWithDistributionStrategyTest(test.TestCase, + parameterized.TestCase): - def test_correctness(self): + @combinations.generate(all_combinations()) + def test_correctness(self, distribution): with self.cached_session(): keras.backend.set_image_data_format('channels_last') num_samples = 10000 @@ -557,43 +737,43 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase): x_train = x_train.astype('float32') y_train = y_train.astype('float32') - model = keras.Sequential() - model.add(keras.layers.Dense(1, input_shape=(1,))) - - # With DistributionStrategy - dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) - dataset_with = dataset_with.batch(32) - strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0', - '/device:GPU:0']) - - model.compile(loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5), - distribute=strategy) - model.fit(x=dataset_with, epochs=1, steps_per_epoch=310) - wts_with_ds = model.get_weights() - - x_predict = [[1], [2], [3], [4]] - predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict, - x_predict)) - predict_dataset_with = predict_dataset_with.batch(2) - predict_with_ds = model.predict(predict_dataset_with, steps=1) - predict_with_ds = np.reshape(predict_with_ds, (4, 1)) - - # Without DistributionStrategy - dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train, + def fit_and_predict(with_distribution=None): + model = keras.Sequential() + model.add(keras.layers.Dense(1, input_shape=(1,))) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + distribute=with_distribution) + + batch_size = 64 + if with_distribution: + batch_size //= with_distribution.num_towers + train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) - dataset_without = dataset_without.batch(64) - - model.compile(loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5)) - model.fit(x=dataset_without, epochs=1, steps_per_epoch=310) - wts_without_ds = model.get_weights() - - x_predict = [[1], [2], [3], [4]] - predict_dataset_without = dataset_ops.Dataset.from_tensor_slices(( - x_predict, x_predict)) - predict_dataset_without = predict_dataset_without.batch(4) - predict_without_ds = model.predict(predict_dataset_without, steps=1) + train_dataset = batch_wrapper(train_dataset, batch_size, distribution) + # Running only 100 steps instead of the full dataset to keep test + # duration small. + model.fit(x=train_dataset, epochs=1, steps_per_epoch=100) + + weights = model.get_weights() + + x_predict = [[1.], [2.], [3.], [4.]] + predict_batch_size = 4 + if with_distribution: + predict_batch_size //= with_distribution.num_towers + predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict, + x_predict)) + predict_dataset = batch_wrapper(predict_dataset, + predict_batch_size, distribution) + predict_result = model.predict(predict_dataset, steps=1) + predict_result = np.reshape(predict_result, (4, 1)) + + return weights, predict_result + + wts_with_ds, predict_with_ds = fit_and_predict( + with_distribution=distribution) + wts_without_ds, predict_without_ds = fit_and_predict( + with_distribution=None) # Verify that the weights are the same within some limits of tolerance. np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3) @@ -602,5 +782,8 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase): np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3) +# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1. + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index bdac4fb58c2ca8c4f6a322a6f477a9e3657b8f93..ba147e78241e5ab45809e498e00debd45a2c49b4 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -183,6 +183,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): "dense/kernel", "dense/bias", "beta1_power", "beta2_power", "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam", "dense/bias/Adam_1" + ], + "Adagrad": [ + "dense/kernel/Adagrad", "dense/kernel", + "dense/bias/Adagrad", "dense/bias" ] } variables = variables_map[optimizer_fn().get_name()] diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py index bb10b546a1907bba26cd0d7e7c5308420adbaf3f..16799104e8112f4391152c0cf2a15af81f8c2c9d 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py @@ -55,14 +55,14 @@ class PrefetchingOpsV2Test(test.TestCase): next_element = iterator.get_next() output = [] + # TODO(rohanj): Modify test to go till the end of the dataset when we + # switch to MultiDeviceIterator. with self.cached_session() as sess: - for _ in range(5): + for _ in range(4): result = sess.run(next_element) self.assertEqual(2, len(result)) output.extend(result) - self.assertEquals(set(range(10)), set(output)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + self.assertEquals(set(range(8)), set(output)) def testPrefetchToTwoDevicesWithReinit(self): if not test_util.is_gpu_available(): @@ -75,14 +75,14 @@ class PrefetchingOpsV2Test(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() + # TODO(rohanj): Modify test to go till the end of the dataset when we + # switch to MultiDeviceIterator. with self.cached_session() as sess: sess.run(iterator.initializer) - for _ in range(5): - sess.run(next_element) - with self.assertRaises(errors.OutOfRangeError): + for _ in range(4): sess.run(next_element) sess.run(iterator.initializer) - for _ in range(5): + for _ in range(4): sess.run(next_element) diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py index 5aa19cf6a9f8411120ed929cecaf93dda6c9edf2..09b351ffa4165656e2fc9666ab4b7725ef061f50 100644 --- a/tensorflow/contrib/distribute/python/single_loss_example.py +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import step_fn from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -59,10 +58,9 @@ def minimize_loss_example(optimizer_fn, def dataset_fn(): dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() - # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be + # TODO(isaprykin): batch with drop_remainder causes shapes to be # fully defined for TPU. Remove this when XLA supports dynamic shapes. - return dataset.apply( - batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True)) + return dataset.batch(1, drop_remainder=True) # An Optimizer instance is created either outside or inside model_fn. outer_optimizer = None diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 97c53ae2b94988ad9938c9d1cf3326e4076e8d6f..9aadc634da5a7591747a4f651cdb45376393402d 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -166,6 +166,7 @@ cuda_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], + tags = ["notap"], ) cuda_py_test( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py index a7bd51430e384c199ca8abd06ef9887e998cc380..1e36b7ff9be4018f6b80a89e5967e5e21e9bd275 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator +from tensorflow.python.ops.linalg import linalg from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index 196cc413353657c2dfadd3a1c87b97518c6f235b..13370497ce706a60b1d0c7f4f148076b354626a7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -22,7 +22,6 @@ import numpy as np from scipy import stats from tensorflow.contrib import distributions -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -30,6 +29,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg from tensorflow.python.platform import test bs = bijectors diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index 25f29452c3949600b8a4153a8585dd7269bd3b2b..ba31697c589006c9fbee2fe68639e5f1daf51f62 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape from tensorflow.python.framework import dtypes @@ -29,6 +28,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.linalg import linalg from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 6959b3e8775d2dd488b4ee3252d143ef376d58f9..b4ad33cf6dbf073419a27f378c8eefdba97c5af7 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond @@ -27,6 +26,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg from tensorflow.python.ops.distributions import distribution as distribution_lib # The following two lines are redundant, in a sense. The first enables diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index d8401801f21afbe8fd042053c6a38a31a2539438..74d9d04fc702a90a5fc5a31f554abe257dd2860d 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops +from tensorflow.python.ops.linalg import linalg from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index d9110947ecdbba1a63669573f46db17b02e512ab..c6a23e4336fffbf7b61490dd3468bc71c7f421cc 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.ops.linalg import linalg from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index f1accaaa4c920344608015c792a2c3606de1337f..49b9de0ab508f5db090bb1349f596da1b2a71b49 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -21,7 +21,6 @@ from __future__ import print_function import math import numpy as np -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util from tensorflow.python.framework import constant_op @@ -36,6 +35,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.linalg import linalg from tensorflow.python.util import deprecation __all__ = [ diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py index 7d2274db9b051e604266074651f4cbd331f20f48..48d093e0754f79725f3e3e900320773aae41e8ad 100644 --- a/tensorflow/contrib/eager/python/evaluator_test.py +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -117,7 +117,7 @@ class EvaluatorTest(test.TestCase): self.assertEqual(6.0, results["mean"].numpy()) def testDatasetGraph(self): - with context.graph_mode(), ops.Graph().as_default(), self.test_session(): + with context.graph_mode(), ops.Graph().as_default(), self.cached_session(): e = SimpleEvaluator(IdentityModel()) ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) init_op, call_op, results_op = e.evaluate_on_dataset(ds) @@ -126,7 +126,7 @@ class EvaluatorTest(test.TestCase): self.assertEqual(6.0, results["mean"]) def testWriteSummariesGraph(self): - with context.graph_mode(), ops.Graph().as_default(), self.test_session(): + with context.graph_mode(), ops.Graph().as_default(), self.cached_session(): e = SimpleEvaluator(IdentityModel()) ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) training_util.get_or_create_global_step() diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 529c99b37c7c37e70afe0d95ccca15200afce60b..3acecd283cda83992bab0c37cf0b8037ed2cf27a 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1056,7 +1056,7 @@ "\n", " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", "\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", " result.append(index_word[predicted_id])\n", "\n", " if index_word[predicted_id] == '':\n", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index 40bc09872482c6062a870a3c274ba792ab83f3de..e0d5e494d432b365b0d1dcff6b634de2e6213a43 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -610,7 +610,7 @@ "\n", " # using a multinomial distribution to predict the word returned by the model\n", " predictions = predictions / temperature\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", " \n", " # We pass the predicted word as the next input to the model\n", " # along with the previous hidden state\n", diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index f1e1f99c57a77a6c6d3cb0578e1f1c776933605d..560fc8c5a22a0e7acf1f37cf7daf7790dc14de19 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -677,7 +677,7 @@ " attention_weights = tf.reshape(attention_weights, (-1, ))\n", " attention_plot[t] = attention_weights.numpy()\n", "\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", "\n", " result += targ_lang.idx2word[predicted_id] + ' '\n", "\n", diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md index fabd7b3e206d3a1954893a2b75361146d4709d00..750bbc66f3555a5d30ac1fd81d87ff54f7389f64 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md @@ -23,4 +23,4 @@ Attribution-ShareAlike License and is available at https://en.wikipedia.org/wiki/List_of_colors:_N-Z This example was adapted from - https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot + https://github.com/random-forests/tensorflow-workshop/tree/master/archive/extras/colorbot diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD deleted file mode 100644 index 638c57d1c92c1dce0ef9e73e9a6ac2369358080b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/scan/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -cuda_py_test( - name = "scan_test", - size = "small", - srcs = ["scan_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow:tensorflow_py", - ], -) - -cuda_py_test( - name = "scan_graph_test", - size = "small", - srcs = ["scan_graph_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow:tensorflow_py", - ], -) diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py deleted file mode 100644 index d4b8c8941ec411912f3089315d038fc4bcd049ae..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Unit test for tf.scan under graph mode execution.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import time - -import numpy as np -import tensorflow as tf - - -class ScanBenchmark(tf.test.Benchmark): - - def runScan(self, n): - elems = np.arange(n) - start_time = time.time() - sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) - with tf.Session() as sess: - sess.run(sum_op) - wall_time = time.time() - start_time - - self.report_benchmark( - name='scan', - iters=n, - wall_time=wall_time) - - def benchmarkScan16000(self): - self.runScan(16000) - - def benchmarkScan32000(self): - self.runScan(32000) - - def benchmarkScan64000(self): - self.runScan(64000) - - def benchmarkScan128000(self): - self.runScan(128000) - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py deleted file mode 100644 index a02fc24c79dae6c2565db8b138b1d7391d169ed8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/scan/scan_test.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Unit test for tf.scan under eager execution.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import time - -import numpy as np -import tensorflow as tf - - -class ScanBenchmark(tf.test.Benchmark): - - def runScan(self, n): - elems = np.arange(n) - start_time = time.time() - _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) - wall_time = time.time() - start_time - - self.report_benchmark( - name='scan', - iters=n, - wall_time=wall_time) - - def benchmarkScan16000(self): - self.runScan(16000) - - def benchmarkScan32000(self): - self.runScan(32000) - - def benchmarkScan64000(self): - self.runScan(64000) - - def benchmarkScan128000(self): - self.runScan(128000) - - -if __name__ == '__main__': - tf.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index dcc7b71d79f207019cec4425eb000b92420b9ca7..9d2d172752c7f3f3ee6eaa11ab8952313a4a3543 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -216,7 +216,7 @@ class MetricsTest(test.TestCase): self.assertEqual(m1.numer.name, "has_space/numer:0") def testGraphWithPlaceholder(self): - with context.graph_mode(), self.test_session() as sess: + with context.graph_mode(), self.cached_session() as sess: m = metrics.Mean() p = array_ops.placeholder(dtypes.float32) accumulate = m(p) @@ -309,7 +309,7 @@ class MetricsTest(test.TestCase): self.assertTrue(old_numer is m.numer) def testMetricsChain(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): m1 = metrics.Mean() m2 = metrics.Mean(name="m2") update_m2 = m2(3.0) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 437b3d965dc841f9c105865bf0df3b321119146b..6db311d52de61359995087fb5ca3d5461f74c4c1 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -18,6 +18,7 @@ py_library( ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":dnn_with_layer_annotations", ":early_stopping", ":export", ":exporter", @@ -126,6 +127,61 @@ py_test( ], ) +py_library( + name = "dnn_with_layer_annotations", + srcs = ["python/estimator/dnn_with_layer_annotations.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:nn", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:summary", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model:utils", + ], +) + +py_test( + name = "dnn_with_layer_annotations_test", + size = "medium", + srcs = ["python/estimator/dnn_with_layer_annotations_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", # b/67510291 + ], + deps = [ + ":dnn_with_layer_annotations", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python/estimator:dnn", + "//tensorflow/python/estimator:dnn_testing_utils", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:pandas_io", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "@six_archive//:six", + ], +) + py_library( name = "dnn_linear_combined", srcs = ["python/estimator/dnn_linear_combined.py"], diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 258860f26340a0934e854f2d1950ead60e413234..78914ecacaf79fd25b33d4159601ab49d2b74c96 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.contrib.estimator.python.estimator.baseline import * from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * +from tensorflow.contrib.estimator.python.estimator.dnn_with_layer_annotations import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * from tensorflow.contrib.estimator.python.estimator.early_stopping import * from tensorflow.contrib.estimator.python.estimator.export import * @@ -76,6 +77,8 @@ _allowed_symbols = [ 'build_raw_supervised_input_receiver_fn', 'build_supervised_input_receiver_fn_from_input_fn', 'SavedModelEstimator' + 'DNNClassifierWithLayerAnnotations', + 'DNNRegressorWithLayerAnnotations', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py new file mode 100644 index 0000000000000000000000000000000000000000..152431d1b205845945cc2c079b747f81d739026f --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py @@ -0,0 +1,434 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Deep Neural Network estimators with layer annotations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import pickle + +from google.protobuf.any_pb2 import Any + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator.canned import dnn +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn +from tensorflow.python.ops.losses import losses +from tensorflow.python.saved_model import utils as saved_model_utils + + +class LayerAnnotationsCollectionNames(object): + """Names for the collections containing the annotations.""" + + UNPROCESSED_FEATURES = 'layer_annotations/unprocessed_features' + PROCESSED_FEATURES = 'layer_annotatons/processed_features' + FEATURE_COLUMNS = 'layer_annotations/feature_columns' + + @classmethod + def keys(cls, collection_name): + return '%s/keys' % collection_name + + @classmethod + def values(cls, collection_name): + return '%s/values' % collection_name + + +def serialize_feature_column(feature_column): + if isinstance(feature_column, feature_column_lib._EmbeddingColumn): # pylint: disable=protected-access + # We can't pickle nested functions, and we don't need the value of + # layer_creator in most cases anyway, so just discard its value. + args = feature_column._asdict() + args['layer_creator'] = None + temp = type(feature_column)(**args) + return pickle.dumps(temp) + return pickle.dumps(feature_column) + + +def _to_any_wrapped_tensor_info(tensor): + """Converts a `Tensor` to a `TensorInfo` wrapped in a proto `Any`.""" + any_buf = Any() + tensor_info = saved_model_utils.build_tensor_info(tensor) + any_buf.Pack(tensor_info) + return any_buf + + +def make_input_layer_with_layer_annotations(original_input_layer, mode): + """Make an input_layer replacement function that adds layer annotations.""" + + def input_layer_with_layer_annotations(features, + feature_columns, + weight_collections=None, + trainable=True, + cols_to_vars=None, + cols_to_output_tensors=None): + """Returns a dense `Tensor` as input layer based on given `feature_columns`. + + Generally a single example in training data is described with + FeatureColumns. + At the first layer of the model, this column oriented data should be + converted + to a single `Tensor`. + + This is like tf.feature_column.input_layer, except with added + Integrated-Gradient annotations. + + Args: + features: A mapping from key to tensors. `_FeatureColumn`s look up via + these keys. For example `numeric_column('price')` will look at 'price' + key in this dict. Values can be a `SparseTensor` or a `Tensor` depends + on corresponding `_FeatureColumn`. + feature_columns: An iterable containing the FeatureColumns to use as + inputs to your model. All items should be instances of classes derived + from `_DenseColumn` such as `numeric_column`, `embedding_column`, + `bucketized_column`, `indicator_column`. If you have categorical + features, you can wrap them with an `embedding_column` or + `indicator_column`. + weight_collections: A list of collection names to which the Variable will + be added. Note that variables will also be added to collections + `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. + trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + cols_to_vars: If not `None`, must be a dictionary that will be filled with + a mapping from `_FeatureColumn` to list of `Variable`s. For example, + after the call, we might have cols_to_vars = {_EmbeddingColumn( + categorical_column=_HashedCategoricalColumn( key='sparse_feature', + hash_bucket_size=5, dtype=tf.string), dimension=10): [