diff --git a/README.md b/README.md index 6fb4486d0de9ff476b5cf1dbd63d66879637df84..63853137cfd30b396f8c7d204811f3e4a1794c07 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ $ python 42 >>> sess.close() ``` +Find more examples of how to do specific tasks in TensorFlow on the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). ## Contribution guidelines diff --git a/RELEASE.md b/RELEASE.md index 18e5dfb16e9ef462a55f79ab73a8cfab9387abe5..e09e9c6190f57adec67c2ae1d85848dabfd9c2a7 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -22,7 +22,7 @@ * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. * Accelerated Linear Algebra (XLA): -* TensorFlow Debugger (tfdbg) CLI: +* TensorFlow Debugger (tfdbg): Fix an issue in which the TensorBoard Debugger Plugin could not handle a total source file size exceeding the gRPC message size limit (4 MB). * `tf.contrib`: * Add `tf.contrib.data.choose_from_datasets()`. * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`. diff --git a/SECURITY.md b/SECURITY.md index e2f6ff353a3c04a6ec6b8ccbaeb75db59fa22d54..0b52fdc7ab84b7bd5bce5d247ede81b40699005c 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= ### Known Vulnerabilities For a list of known vulnerabilities and security advisories for TensorFlow, -(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md)[click here]. +[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md). 
diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 9b07669a5d8e4da6ce202fc9196185b91d8e7e2e..4e212e96dcfe4ad2b2055ea9abb150e9fd5c1f28 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -154,6 +154,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "linux_s390x", + values = {"cpu": "s390x"}, + visibility = ["//visibility:public"], +) + config_setting( name = "debug", values = { @@ -424,6 +430,22 @@ filegroup( data = glob(["docs_src/**/*.md"]), ) +cc_library( + name = "grpc", + deps = select({ + ":linux_s390x": ["@grpc//:grpc_unsecure"], + "//conditions:default": ["@grpc"], + }), +) + +cc_library( + name = "grpc++", + deps = select({ + ":linux_s390x": ["@grpc//:grpc++_unsecure"], + "//conditions:default": ["@grpc//:grpc++"], + }), +) + # A shared object which includes registration mechanisms for ops and # kernels. Does not include the implementations of any ops or kernels. Instead, # the library which loads libtensorflow_framework.so @@ -451,6 +473,15 @@ filegroup( tf_cc_shared_object( name = "libtensorflow_framework.so", framework_so = [], + linkopts = select({ + "//tensorflow:darwin": [], + "//tensorflow:windows": [], + "//tensorflow:windows_msvc": [], + "//conditions:default": [ + "-Wl,--version-script", # This line must be directly followed by the version_script.lds file + "$(location //tensorflow:tf_framework_version_script.lds)", + ], + }), linkstatic = 1, visibility = ["//visibility:public"], deps = [ @@ -460,6 +491,7 @@ tf_cc_shared_object( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", "//tensorflow/core:lib_internal_impl", "//tensorflow/stream_executor:stream_executor_impl", + "//tensorflow:tf_framework_version_script.lds", ] + tf_additional_binary_deps(), ) @@ -539,14 +571,17 @@ exports_files( ) gen_api_init_files( - name = "python_api_gen", + name = "tensorflow_python_api_gen", srcs = ["api_template.__init__.py"], root_init_template = "api_template.__init__.py", ) py_library( 
name = "tensorflow_py", - srcs = [":python_api_gen"], + srcs = [ + ":tensorflow_python_api_gen", + "//tensorflow/python/estimator/api:estimator_python_api_gen", + ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = ["//tensorflow/python"], diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 9b0d7d48afd058607badc90b95c9dca0c4ceaa31..9662d7b478ba61c69edc20b0d47293f9939e7881 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -22,7 +22,22 @@ from __future__ import print_function from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import # API IMPORTS PLACEHOLDER -from tensorflow.python.util.lazy_loader import LazyLoader +try: + import os # pylint: disable=g-import-not-at-top + # Add `estimator` attribute to allow access to estimator APIs via + # "tf.estimator..." + from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top + + # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." + # style imports. 
+ from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top + __path__ += [os.path.dirname(estimator_api.__file__)] + del estimator_api + del os +except (ImportError, AttributeError): + print('tf.estimator package not installed.') + +from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index b86b277ac3200b88ae03490a6c1b64d464e81950..12f0d8bff4720d98b7f45b113dc62c881e32a399 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -631,7 +631,22 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, "Failed to allocate memory to serialize message of type '", in.GetTypeName(), "' and size ", proto_size); } - in.SerializeToArray(buf, proto_size); + // SerializeToArray takes size as an int. + // This next 'if' is a workaround till we update to depend on a version + // of protocol buffers that includes + // https://github.com/google/protobuf/pull/4739 + if (proto_size > std::numeric_limits::max()) { + return InvalidArgument("Cannot serialize protocol buffer of type ", + in.GetTypeName(), " as the serialized size (", + proto_size, + "bytes) would be larger than the limit (", + std::numeric_limits::max(), " bytes)"); + } + if (!in.SerializeToArray(buf, proto_size)) { + return InvalidArgument("Unable to serialize ", in.GetTypeName(), + " protocol buffer, perhaps the serialized size (", + proto_size, " bytes) is too large?"); + } out->data = buf; out->length = proto_size; out->data_deallocator = [](void* data, size_t length) { @@ -2108,7 +2123,7 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status) { GraphDef def; - if (!def.ParseFromArray(graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, 
graph_def->data, graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return nullptr; } @@ -2138,7 +2153,7 @@ void TF_GraphImportGraphDefWithReturnOutputs( return; } GraphDef def; - if (!def.ParseFromArray(graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return; } diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 27ff5f7211b0592637a173d337f93c10d376443f..992d1afd5fcb0641794bb2abbe5ab20a287d3b62 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -142,8 +142,10 @@ void TestRemoteExecute(bool async) { TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), status); - TFE_ContextOptionsSetAsync(opts, static_cast(1)); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -205,6 +207,83 @@ void TestRemoteExecute(bool async) { TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } +void TestRemoteExecuteSilentCopies(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE( + tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + + // Handles are on task0, but op is on remote (task1). 
+ TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. 
+ worker_server.release(); +} + +TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); } +TEST(CAPI, RemoteExecuteSilentCopiesAsync) { + TestRemoteExecuteSilentCopies(true); +} + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d6a4f141b6bb8ccadb77f1fa83b5fb742d78f70f..dfdef88945deca376368edd6f7aa322b1e1cbf94 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -273,6 +273,12 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) { return ""; // Prevent missing return warning } +bool IsEmptyList(const AttrValue::ListValue& list) { + return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 && + list.b_size() == 0 && list.type_size() == 0 && + list.shape_size() == 0 && list.tensor_size() == 0; +} + string ToCamelCase(const string& str) { string result; const char joiner = '_'; @@ -297,9 +303,9 @@ string ToCamelCase(const string& str) { // indicate whether to treat the type as const when accepting the C++ type as an // argument to a function. 
std::pair AttrTypeName(StringPiece attr_type) { - static const std::unordered_map, - StringPieceHasher> - attr_type_map{ + static const auto* attr_type_map = + new std::unordered_map, + StringPieceHasher>{ {"string", {"StringPiece", false}}, {"list(string)", {"gtl::ArraySlice", true}}, {"int", {"int64", false}}, @@ -317,14 +323,34 @@ std::pair AttrTypeName(StringPiece attr_type) { {"func", {"NameAttrList", true}}, }; - auto entry = attr_type_map.find(attr_type); - if (entry == attr_type_map.end()) { + auto entry = attr_type_map->find(attr_type); + if (entry == attr_type_map->end()) { LOG(FATAL) << "Unsupported Attr type: " << attr_type; return {"", false}; } return entry->second; } +const char* ListElementTypeName(StringPiece attr_type) { + static const auto* attr_list_type_map = + new std::unordered_map{ + {"list(string)", "string"}, + {"list(int)", "int"}, + {"list(float)", "float"}, + {"list(bool)", "bool"}, + {"list(type)", "DataType"}, + {"list(shape)", "PartialTensorShape"}, + {"list(tensor)", "TensorProto"}, + }; + + auto entry = attr_list_type_map->find(attr_type); + if (entry == attr_list_type_map->end()) { + LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type; + return ""; + } + return entry->second; +} + bool IsCPPKeyword(StringPiece name) { static const std::unordered_set // Keywords obtained from http://en.cppreference.com/w/cpp/keyword @@ -668,6 +694,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; + string defaults_static_storage; for (int i = 0; i < graph_op_def.attr_size(); ++i) { const auto& attr(graph_op_def.attr(i)); @@ -705,11 +732,32 @@ string OpInfo::GetOpAttrStruct() const { "_ = x;\n"); strings::StrAppend(&setters, " return ret;\n }\n\n"); - strings::StrAppend( - &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(), - "_ = ", - PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), - ";\n"); + string 
field_initiliazer; + auto& default_value = api_def_attr.default_value(); + if (default_value.value_case() == AttrValue::kList && + !IsEmptyList(default_value.list())) { + // Non-empty lists need static storage for their defaults. Define a + // function with static local variable that stores the array. + strings::StrAppend(&defaults_static_storage, " static ", + attr_type_name, " Default_", api_def_attr.rename_to(), + "() {\n"); + strings::StrAppend( + &defaults_static_storage, " static const ", + ListElementTypeName(attr.type()), " kStorage[] = ", + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), + ";\n"); + strings::StrAppend(&defaults_static_storage, " return ", + attr_type_name, "(kStorage);\n }\n"); + // Set the field_initializer to call the defined function. + strings::StrAppend(&field_initiliazer, "Default_", + api_def_attr.rename_to(), "()"); + } else { + field_initiliazer = + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()); + } + strings::StrAppend(&struct_fields, " ", attr_type_name, " ", + api_def_attr.rename_to(), "_ = ", field_initiliazer, + ";\n"); } if (struct_fields.empty()) { @@ -721,6 +769,9 @@ string OpInfo::GetOpAttrStruct() const { string struct_decl = MakeComment(attrs_comment, " "); strings::StrAppend(&struct_decl, " struct Attrs {\n"); strings::StrAppend(&struct_decl, setters, struct_fields); + if (!defaults_static_storage.empty()) { + strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage); + } strings::StrAppend(&struct_decl, " };\n"); return struct_decl; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index ab8cd8f4bcd3b5a102692b47cfedfce6a9d9cc47..8c74014614789758192691ee065f92759a113a7a 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -181,6 +181,7 @@ cc_library( "//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", + 
"//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", ], ) @@ -316,7 +317,6 @@ cc_library( ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", - "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", @@ -342,6 +342,7 @@ cc_library( "//tensorflow/compiler/jit/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", ], ) diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 6d1e3325ebd35b9608ea273fb7de39bad381e60d..9448b8ebde09b73bf26fd8c5ad118105045ff452 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" -#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" @@ -107,41 +106,11 @@ void MarkGuaranteedConstants( } } -// A node/slot pair. -// TODO(phawkins): is there a common definition of this? 
-struct NodeSlot { - NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {} - NodeSlot(const Node* node, int slot) - : node(node), slot(slot), dtype(DT_INVALID) {} - NodeSlot(const Node* node, int slot, DataType dtype) - : node(node), slot(slot), dtype(dtype) {} - - const Node* node; - int slot; - - // Optional: used to record the destination type of a source NodeSlot in case - // the source output is a Ref type that is cast to a Tensor at the - // destination. - DataType dtype; - - bool operator==(const NodeSlot& other) const { - return node == other.node && slot == other.slot && dtype == other.dtype; - } - - // Leave dtype out of the hash since there are never two NodeSlots with the - // same node and slot and different dtypes. - struct Hasher { - uint64 operator()(NodeSlot const& s) const { - return Hash64Combine(std::hash()(s.node), - std::hash()(s.slot)); - } - }; - - struct PairHasher { - uint64 operator()(std::pair const& s) const { - return Hash64Combine(Hasher()(s.first), Hasher()(s.second)); - } - }; +struct OutputInputTensorPairHasher { + uint64 operator()(std::pair const& s) const { + return Hash64Combine(OutputTensor::Hash()(s.first), + InputTensor::Hash()(s.second)); + } }; // TODO(phawkins) add a canonical copy of these operator names and refactor @@ -182,8 +151,7 @@ class Encapsulator { // Write a copy of the input graph to 'graph_out', where the subgraphs are // replaced with calls to the new functions. - Status BuildOutputGraph(bool parallel_checking, Graph* graph_out, - FunctionLibraryDefinition* library); + Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library); private: // A subgraph of the input, all marked with a common 'group_attribute' @@ -271,7 +239,7 @@ class Encapsulator { // Adds the function call node to graph_out. 
Status AddFunctionCallNode( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. Status AddOutsideCompilationHostIONodes( @@ -284,11 +252,9 @@ class Encapsulator { // Subgraph. void GetOutsideCompilationSubgraphNames(std::vector* names) const; - // Returns the Node that inputs to the function should be wired up to. - Node* GetCallNodeForInputs() const; - - // Returns the Node that outputs to the function should be wired up to. - Node* GetCallNodeForOutputs() const; + // Returns the Node that the inputs and outputs of the function should be + // wired up to. + Node* GetCallNode() const; // Returns the index of the arg that the dst of edge should connect to. int GetArgIndexForEdge(const Edge* edge) const; @@ -380,7 +346,7 @@ class Encapsulator { // Map from source (producer node/slot) tensors in the original graph to // input index (slot number in the HostCompute/RecvAtHost nodes that will // be created) for the outside_compilation subgraph. - std::unordered_map inputs; + std::unordered_map inputs; // Set of nodes in the original graph that are the source of control edges // that cross from the containing compiled subgraph into the @@ -396,8 +362,15 @@ class Encapsulator { // node/slot) tensors in the original graph to output index (slot number // in the SendFromHost/HostCompute nodes that will be created) for the // outside_compilation subgraph. 
- std::unordered_map outputs_by_src; - std::unordered_map outputs_by_dst; + struct ArgNumAndType { + int index; + DataType dtype; + + ArgNumAndType(int i, DataType t) : index(i), dtype(t) {} + }; + std::unordered_map + outputs_by_src; + std::unordered_map outputs_by_dst; // Set of nodes in the original graph that are the destination of control // edges that cross from the outside_compilation subgraph into the @@ -425,12 +398,6 @@ class Encapsulator { OutsideCompilationSubgraph* LookupOrCreateOutsideCompilationSubgraph( const string& outside_compilation_id); - // Builds a ParallelCheck op that compares the output of the original - // subgraph with the encapsulated subgraph. - Status BuildParallelCheckOp( - const std::unordered_map& node_images, - Graph* graph_out); - // Builds a placeholder node used to provide the key input to a RecvAtHost // or SendFromHost node. This placeholder node will be removed by a later // pass. @@ -482,26 +449,21 @@ class Encapsulator { // Not owned. Node* host_compute_key_placeholder_ = nullptr; - // Function call node(s) in the output graph. Not owned. - // If parallel_checking is enabled, 'call_node_inputs' is the function call - // node to which inputs should be fed, and 'call_node_outputs' is the - // parallel check op from which outputs should be read. If parallel checking - // is disabled, both point to the function call node. - Node* call_node_inputs_; - Node* call_node_outputs_; + // Function call node in the output graph. Not owned. + Node* call_node_; // Maps from source (producer node/slot) and destination // (consumer node/slot) tensors in the input graph to _Arg numbers in // the subgraph. The source map is one-to-one, whereas the dest map may be // many-to-one. - std::unordered_map args_by_src_; - std::unordered_map args_by_dst_; + std::unordered_map args_by_src_; + std::unordered_map args_by_dst_; - // The _Arg nodes in the subgraph, in order by argument number. + // The arguments to the subgraph, in order. 
std::vector args_; // Map from source tensor in the input graph to result #. - std::unordered_map results_; + std::unordered_map results_; // The outside_compilation clusters in this subgraph. std::unordered_map @@ -541,13 +503,12 @@ class Encapsulator { // Copies all nodes that aren't in a compiled subgraph to the output graph. Status CopyNodesToOutputGraph( - bool parallel_checking, Graph* graph_out, - std::unordered_map* node_images); + Graph* graph_out, std::unordered_map* node_images); // Adds function call nodes for each compiled subgraph. Status AddFunctionCallNodes( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all // outside_compilation subgraphs. @@ -598,9 +559,9 @@ class Encapsulator { const string& src_outside_compilation_id, const string& dst_func_id, const string& dst_outside_compilation_id, const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out, - std::unordered_set, NodeSlot::PairHasher>* - edges_added); + Graph* graph_out, + std::unordered_set, + OutputInputTensorPairHasher>* edges_added); // Adds control dependencies between subgraph call nodes that have // dependencies via outside_compilation edges. @@ -609,7 +570,7 @@ class Encapsulator { // Adds all edges to the output graph. Status AddEdgesToOutputGraph( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Constructs a minimal shape inference graph that can be used to determine // the shape of send_node at the time that the subgraph is compiled. 
@@ -729,20 +690,14 @@ void TopologicalClusterSort( } // namespace -Node* Encapsulator::Subgraph::GetCallNodeForInputs() const { - return call_node_inputs_; -} - -Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const { - return call_node_outputs_; -} +Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; } int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const { - return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input())); + return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input())); } int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const { - return results_.at(NodeSlot(edge->src(), edge->src_output())); + return results_.at(OutputTensor(edge->src(), edge->src_output())); } Node* Encapsulator::Subgraph::GetRecvAtHostNode( @@ -754,7 +709,7 @@ Node* Encapsulator::Subgraph::GetRecvAtHostNode( int Encapsulator::Subgraph::GetRecvAtHostSlot( const string& outside_compilation_subgraph_name, const Edge* edge) const { return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .inputs.at(NodeSlot(edge->src(), edge->src_output())); + .inputs.at(OutputTensor(edge->src(), edge->src_output())); } Node* Encapsulator::Subgraph::GetSendFromHostNode( @@ -766,7 +721,7 @@ Node* Encapsulator::Subgraph::GetSendFromHostNode( int Encapsulator::Subgraph::GetSendFromHostSlot( const string& outside_compilation_subgraph_name, const Edge* edge) const { return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); + .outputs_by_dst.at(InputTensor(edge->dst(), edge->dst_input())); } Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { @@ -791,10 +746,10 @@ Status Encapsulator::Subgraph::RecordArg( std::vector>* src_arg_pairs) { Node* src_node = edge->src(); int src_slot = edge->src_output(); - std::unordered_map::iterator iter; + std::unordered_map::iterator iter; bool inserted; - std::tie(iter, 
inserted) = - args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size()); + std::tie(iter, inserted) = args_by_src_.emplace( + OutputTensor(src_node, src_slot), args_by_src_.size()); int arg_index = iter->second; if (inserted) { NodeDef arg_def; @@ -815,7 +770,7 @@ Status Encapsulator::Subgraph::RecordArg( Node* dst_node = edge->dst(); Node* dst_image = node_images.at(dst_node); int dst_slot = edge->dst_input(); - args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index; + args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index; graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot); return Status::OK(); } @@ -826,10 +781,10 @@ Status Encapsulator::Subgraph::RecordResult( Node* src_node = edge->src(); Node* src_image = node_images.at(src_node); int src_slot = edge->src_output(); - std::unordered_map::iterator iter; + std::unordered_map::iterator iter; bool inserted; std::tie(iter, inserted) = - results_.emplace(NodeSlot(src_node, src_slot), results_.size()); + results_.emplace(OutputTensor(src_node, src_slot), results_.size()); int ret_index = iter->second; if (inserted) { NodeDef ret_def; @@ -867,8 +822,8 @@ void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( outside_subgraph->control_inputs.insert(edge->src()); } else { int input_index = outside_subgraph->inputs.size(); - outside_subgraph->inputs.emplace(NodeSlot(edge->src(), edge->src_output()), - input_index); + outside_subgraph->inputs.emplace( + OutputTensor(edge->src(), edge->src_output()), input_index); } } @@ -882,11 +837,13 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( DataType dtype = edge->dst()->input_type(edge->dst_input()); auto output_iter = outside_subgraph->outputs_by_src - .emplace(NodeSlot(edge->src(), edge->src_output(), dtype), - outside_subgraph->outputs_by_src.size()) + .emplace(OutputTensor(edge->src(), edge->src_output()), + OutsideCompilationSubgraph::ArgNumAndType( + outside_subgraph->outputs_by_src.size(), dtype)) .first; - int 
output_index = output_iter->second; - outside_subgraph->outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = + const int output_index = output_iter->second.index; + outside_subgraph + ->outputs_by_dst[InputTensor(edge->dst(), edge->dst_input())] = output_index; } } @@ -968,7 +925,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (const auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; Node* src_image = node_images.at(src_node); - int src_slot = input_src.first.slot; + int src_slot = input_src.first.index; int input_index = input_src.second; DataType dtype = src_node->output_type(src_slot); @@ -976,8 +933,8 @@ Status Encapsulator::Subgraph::AddHostComputes( input_dtypes[input_index] = dtype; } for (const auto& output : oc_subgraph.outputs_by_src) { - DataType dtype = output.first.dtype; - int output_index = output.second; + DataType dtype = output.second.dtype; + int output_index = output.second.index; output_dtypes[output_index] = dtype; } @@ -1015,7 +972,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; Node* src_image = node_images.at(src_node); - int src_slot = input_src.first.slot; + int src_slot = input_src.first.index; int input_index = input_src.second; graph_->AddEdge(src_image, src_slot, host_compute, input_index); } @@ -1037,7 +994,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (const auto& output : oc_subgraph.outputs_by_dst) { const Node* dst_node = output.first.node; Node* dst_image = node_images.at(dst_node); - int dst_slot = output.first.slot; + int dst_slot = output.first.index; int output_index = output.second; graph_->AddEdge(host_compute, output_index, dst_image, dst_slot); @@ -1075,7 +1032,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { VLOG(2) << 
"ConnectSequencerToCallNode"; - graph_out->AddControlEdge(sequencer_, call_node_inputs_); + graph_out->AddControlEdge(sequencer_, call_node_); } } @@ -1090,14 +1047,19 @@ Status Encapsulator::Subgraph::BuildFunctionDef( call_node_def_.set_device(device_); if (rewrite_subgraph_fn) { + std::vector arg_source_tensors(args_by_src_.size()); + for (const auto& arg : args_by_src_) { + arg_source_tensors.at(arg.second) = arg.first; + } // Initialize the input and output permutations to the identity. std::vector input_permutation(args_by_src_.size()); std::iota(input_permutation.begin(), input_permutation.end(), 0); std::vector output_permutation(results_.size()); std::iota(output_permutation.begin(), output_permutation.end(), 0); - TF_RETURN_IF_ERROR(rewrite_subgraph_fn( - &graph_, &input_permutation, &output_permutation, &call_node_def_)); + TF_RETURN_IF_ERROR( + rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation, + &output_permutation, &call_node_def_)); // Apply the input/output permutations to the 'args_by_...' and 'results_' // mappings, so when we build edges in BuildOutputGraph() we @@ -1200,83 +1162,16 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( return Status::OK(); } -Status Encapsulator::Subgraph::BuildParallelCheckOp( - const std::unordered_map& node_images, - Graph* graph_out) { - // Build an index mapping output positions to node/slot pairs in the - // original graph. - std::vector results_by_num(results_.size()); - for (const auto& entry : results_) { - results_by_num[entry.second] = entry.first; - } - - // Build a parallel check NodeDef. 
- int num_results = results_by_num.size(); - std::vector result_dtypes(num_results); - std::vector expected_outputs(num_results); - std::vector actual_outputs(num_results); - for (int i = 0; i < num_results; ++i) { - const NodeSlot& node_slot = results_by_num[i]; - result_dtypes[i] = node_slot.node->output_type(node_slot.slot); - expected_outputs[i] = - NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(), - node_slot.slot, result_dtypes[i]); - actual_outputs[i] = - NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]); - } - // Assign the parallel check op to a CPU on the same task as the cluster it is - // checking. - string device, dummy; - if (!DeviceNameUtils::SplitDeviceName( - call_node_inputs_->assigned_device_name(), &device, &dummy)) { - return errors::InvalidArgument("Could not parse device name"); - } - strings::StrAppend(&device, "/cpu:0"); - - NodeDef check_def; - TF_RETURN_IF_ERROR( - NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(), - "_parallel_check")), - "ParallelCheck") - .Device(device) - .Attr("T", result_dtypes) - .Input(expected_outputs) - .Input(actual_outputs) - .Finalize(&check_def)); - - Status s; - Node* check_op = graph_out->AddNode(check_def, &s); - if (!s.ok()) return s; - check_op->set_assigned_device_name(device); - - // TODO(phawkins): it seems redundant to call AddEdge as well as - // pass Inputs to the NodeDefBuilder, but I have been unable to find a - // way to avoid it. 
- for (int i = 0; i < num_results; ++i) { - const NodeSlot& node_slot = results_by_num[i]; - graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op, - i); - graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i); - } - - call_node_outputs_ = check_op; - return Status::OK(); -} - Status Encapsulator::Subgraph::AddFunctionCallNode( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { Status s; - call_node_inputs_ = graph_out->AddNode(call_node_def_, &s); + call_node_ = graph_out->AddNode(call_node_def_, &s); if (!s.ok()) return s; // Copy the assigned device and the key_annotation over. - call_node_inputs_->set_assigned_device_name(device_); - call_node_outputs_ = call_node_inputs_; + call_node_->set_assigned_device_name(device_); - if (parallel_checking) { - TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out)); - } return Status::OK(); } @@ -1315,7 +1210,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( for (const auto& input : oc_subgraph->inputs) { const Node* src_node = input.first.node; - int src_slot = input.first.slot; + int src_slot = input.first.index; int input_index = input.second; DataType dtype = src_node->output_type(src_slot); @@ -1369,8 +1264,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( for (const auto& output : oc_subgraph->outputs_by_src) { const Node* src_node = output.first.node; Node* src_image = node_images.at(src_node); - int src_slot = output.first.slot; - int output_index = output.second; + int src_slot = output.first.index; + int output_index = output.second.index; DataType dtype = src_node->output_type(src_slot); dtypes[output_index] = dtype; @@ -1627,27 +1522,17 @@ Status Encapsulator::BuildFunctionDefs( } Status Encapsulator::CopyNodesToOutputGraph( - bool parallel_checking, Graph* graph_out, - std::unordered_map* node_images) { + Graph* graph_out, std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) 
{ string func_id; string outside_compilation_id; TF_RETURN_IF_ERROR( GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); - // Don't copy nodes that going to be encapsulated, unless parallel checking - // is enabled. - if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking) - continue; + // Don't copy nodes that are going to be encapsulated. + if (IsInSubgraph(func_id, outside_compilation_id)) continue; Node* image = graph_out->CopyNode(node); - if (!outside_compilation_id.empty()) { - if (parallel_checking) { - return errors::InvalidArgument( - "Parallel checking is not supported when outside_compilation " - "clusters are present."); - } - } (*node_images)[node] = image; } (*node_images)[graph_in_->source_node()] = graph_out->source_node(); @@ -1657,10 +1542,10 @@ Status Encapsulator::CopyNodesToOutputGraph( Status Encapsulator::AddFunctionCallNodes( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { for (auto& subgraph_entry : subgraphs_) { - TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode( - node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR( + subgraph_entry.second.AddFunctionCallNode(node_images, graph_out)); } return Status::OK(); } @@ -1694,7 +1579,7 @@ Status Encapsulator::FindOutputImageOfEdgeSrc( } else { // The edge is from a subgraph to a regular node in the output graph so // use the subgraph's call node output. - *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs(); + *src_image = subgraphs_.at(src_func_id).GetCallNode(); } } else { // The source of the edge is in the output graph so use the node image in @@ -1742,7 +1627,7 @@ Status Encapsulator::FindOutputImageOfEdgeDst( } else { // The edge is to a subgraph from a regular node in the output graph so // use the subgraph's call node input. 
- *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs(); + *dst_image = subgraphs_.at(dst_func_id).GetCallNode(); } } else { // The destination of the edge is in the output graph so use the node image @@ -1778,10 +1663,9 @@ Status Encapsulator::CopyEdgeToOutputGraph( const Edge* edge, const string& src_func_id, const string& src_outside_compilation_id, const string& dst_func_id, const string& dst_outside_compilation_id, - const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out, - std::unordered_set, NodeSlot::PairHasher>* - edges_added) { + const std::unordered_map& node_images, Graph* graph_out, + std::unordered_set, + OutputInputTensorPairHasher>* edges_added) { Node* src_image; TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc( src_func_id, src_outside_compilation_id, dst_func_id, @@ -1796,16 +1680,12 @@ Status Encapsulator::CopyEdgeToOutputGraph( if (edge->IsControlEdge()) { // Add the control edge, if we have not already added it, using the images // determined above (potentially call operators or RecvAtHost/SendFromHost). - if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) + if (edges_added + ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1)) .second) { graph_out->AddControlEdge(src_image, dst_image); } - // If parallel checking is enabled, also add a control edge to the - // corresponding parallel check op. - if (parallel_checking) { - graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); - } return Status::OK(); } @@ -1817,18 +1697,10 @@ Status Encapsulator::CopyEdgeToOutputGraph( FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id, dst_func_id, dst_outside_compilation_id, edge); - if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) && - parallel_checking) { - // If we are parallel checking, also feed the tensor as an input to the - // corresponding parallel check subgraph. 
- graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), - edge->dst_input()); - } - // Add the edge, if we have not already added it. if (edges_added - ->emplace(NodeSlot(src_image, src_output), - NodeSlot(dst_image, dst_input)) + ->emplace(OutputTensor(src_image, src_output), + InputTensor(dst_image, dst_input)) .second) { graph_out->AddEdge(src_image, src_output, dst_image, dst_input); } @@ -1839,8 +1711,8 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { for (const auto& ancestors : subgraph_ancestors_) { const string& subgraph = ancestors.first; for (const string& ancestor : ancestors.second) { - graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNodeForOutputs(), - subgraphs_[subgraph].GetCallNodeForInputs()); + graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNode(), + subgraphs_[subgraph].GetCallNode()); } } return Status::OK(); @@ -1848,11 +1720,12 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { Status Encapsulator::AddEdgesToOutputGraph( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { // Set of edges already added to the output graph, represented as (src, dst) // pairs. We use the set to deduplicate edges; multiple edges in the input // graph may map to one edge in the output graph. 
- std::unordered_set, NodeSlot::PairHasher> + std::unordered_set, + OutputInputTensorPairHasher> edges_added; for (const Edge* edge : graph_in_->edges()) { @@ -1870,16 +1743,6 @@ Status Encapsulator::AddEdgesToOutputGraph( if (IsInSubgraph(src_func_id, src_outside_compilation_id) && IsInSubgraph(dst_func_id, dst_outside_compilation_id) && src_func_id == dst_func_id) { - if (parallel_checking) { - Node* src_image = node_images.at(edge->src()); - Node* dst_image = node_images.at(edge->dst()); - if (edge->IsControlEdge()) { - graph_out->AddControlEdge(src_image, dst_image); - } else { - graph_out->AddEdge(src_image, edge->src_output(), dst_image, - edge->dst_input()); - } - } continue; } @@ -1887,8 +1750,7 @@ Status Encapsulator::AddEdgesToOutputGraph( // unclustered graph. TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph( edge, src_func_id, src_outside_compilation_id, dst_func_id, - dst_outside_compilation_id, node_images, parallel_checking, graph_out, - &edges_added)); + dst_outside_compilation_id, node_images, graph_out, &edges_added)); } for (auto& subgraph_entry : subgraphs_) { @@ -2504,18 +2366,15 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( return Status::OK(); } -Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, +Status Encapsulator::BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library) { // Map from nodes in the input graph to nodes in the output graph. 
std::unordered_map node_images; - TF_RETURN_IF_ERROR( - CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images)); - TF_RETURN_IF_ERROR( - AddFunctionCallNodes(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images)); + TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out)); TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out)); - TF_RETURN_IF_ERROR( - AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out)); TF_RETURN_IF_ERROR( GetShapeInfoForOutsideCompilationSends(graph_out, library)); @@ -2528,8 +2387,8 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, Status EncapsulateSubgraphsInFunctions( string group_attribute, string outside_compilation_attribute, const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, - bool parallel_checking, bool reuse_existing_functions, - std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library) { Status s; Encapsulator encapsulator(std::move(group_attribute), @@ -2543,8 +2402,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr out(new Graph(library)); out->set_versions(graph_in.versions()); - TF_RETURN_IF_ERROR( - encapsulator.BuildOutputGraph(parallel_checking, out.get(), library)); + TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library)); *graph_out = std::move(out); return Status::OK(); @@ -2585,8 +2443,6 @@ static Status RenumberArguments(Graph* graph, Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; - legacy_flags::EncapsulateSubgraphsPassFlags* flags = - legacy_flags::GetEncapsulateSubgraphsPassFlags(); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph, 
options.flib_def); @@ -2602,68 +2458,70 @@ Status EncapsulateSubgraphsPass::Run( FunctionLibraryRuntime* flr = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); - auto rewrite_subgraph = [flr](std::unique_ptr* subgraph, - std::vector* input_permutation, - std::vector* output_permutation, - NodeDef* node) { - // Optimize the subgraph. - OptimizeGraph(flr, subgraph); - - const int num_args = input_permutation->size(); - std::vector const_args(num_args); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); - - DataTypeVector arg_types(num_args); - TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); - - // Compute a permutation of the arguments such that the constant arguments - // are first. - const int num_consts = - std::count(const_args.begin(), const_args.end(), true); - - const int num_resources = - std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE); - const int num_nonconsts = num_args - num_resources - num_consts; - if (num_nonconsts < 0) { - return errors::Internal("num_nonconsts should be >= 0, was ", - num_nonconsts); - } + auto rewrite_subgraph = + [flr](const std::vector& arg_source_tensors, + std::unique_ptr* subgraph, + std::vector* input_permutation, + std::vector* output_permutation, NodeDef* node) { + // Optimize the subgraph. + OptimizeGraph(flr, subgraph); + + const int num_args = input_permutation->size(); + std::vector const_args(num_args); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); + + DataTypeVector arg_types(num_args); + TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); + + // Compute a permutation of the arguments such that the constant + // arguments are first. 
+ const int num_consts = + std::count(const_args.begin(), const_args.end(), true); + + const int num_resources = + std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE); + const int num_nonconsts = num_args - num_resources - num_consts; + if (num_nonconsts < 0) { + return errors::Internal("num_nonconsts should be >= 0, was ", + num_nonconsts); + } - int const_pos = 0; - int arg_pos = num_consts; - int resource_pos = num_consts + num_nonconsts; - for (int i = 0; i < num_args; ++i) { - if (const_args[i]) { - if (arg_types[i] == DT_RESOURCE) { - return errors::Internal( - "Resource arguments cannot be constant (argument ", i, ")"); + int const_pos = 0; + int arg_pos = num_consts; + int resource_pos = num_consts + num_nonconsts; + for (int i = 0; i < num_args; ++i) { + if (const_args[i]) { + if (arg_types[i] == DT_RESOURCE) { + return errors::Internal( + "Resource arguments cannot be constant (argument ", i, ")"); + } + (*input_permutation)[i] = const_pos; + ++const_pos; + } else if (arg_types[i] == DT_RESOURCE) { + (*input_permutation)[i] = resource_pos; + ++resource_pos; + } else { + (*input_permutation)[i] = arg_pos; + ++arg_pos; + } } - (*input_permutation)[i] = const_pos; - ++const_pos; - } else if (arg_types[i] == DT_RESOURCE) { - (*input_permutation)[i] = resource_pos; - ++resource_pos; - } else { - (*input_permutation)[i] = arg_pos; - ++arg_pos; - } - } - // Renumber argument nodes in the graph. - TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation)); + // Renumber argument nodes in the graph. + TF_RETURN_IF_ERROR( + RenumberArguments(subgraph->get(), *input_permutation)); - // TODO(phawkins): add a forward is-constant analysis, similarly split - // outputs into host-memory constants and device-memory non-constants. + // TODO(phawkins): add a forward is-constant analysis, similarly split + // outputs into host-memory constants and device-memory non-constants. 
- AddNodeAttr(kXlaCompiledKernelAttr, true, node); - AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); - AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); - return Status::OK(); - }; + AddNodeAttr(kXlaCompiledKernelAttr, true, node); + AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); + AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); + return Status::OK(); + }; TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, - rewrite_subgraph, flags->tf_xla_parallel_checking, + rewrite_subgraph, /*reuse_existing_functions=*/false, &graph_out, library)); if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 5fee36f022a7515504cb6faa5cca658481b784c5..926589546fec72048485d30966f31b24e44b1245 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -28,6 +28,9 @@ limitations under the License. namespace tensorflow { // A rewriting function to apply to each subgraph during encapsulation. +// 'arg_source_tensors' are the tensors corresponding to the arguments in the +// original source graph (*not* 'graph'). +// // 'graph' is the subgraph. The rewriting may renumber the inputs and outputs; // 'input_permutation' is a mapping from old argument numbers to new argument // numbers, whereas 'output_permutation' is the same for outputs. Both @@ -37,6 +40,7 @@ namespace tensorflow { // The rewrite may also change the NodeDef's operator name, and that // name will be used as the name of the generated function. 
typedef std::function& arg_source_tensors, std::unique_ptr* graph, std::vector* input_permutation, std::vector* output_permutation, NodeDef* node_def)> RewriteSubgraphFn; @@ -61,10 +65,6 @@ typedef std::function* graph_out, FunctionLibraryDefinition* library); + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate // subgraphs pass and that should in turn be compiled via XlaLaunch operators. diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index eef113a3547f0b2f648680d5f51650f70dbbd261..4eb389e0c653f2d32c17f448687f865a44a11b96 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -511,7 +511,6 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { std::unique_ptr graph_out; s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_out, lib_def.get()); if (!s.ok()) return s; @@ -560,8 +559,9 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { Node* b = Input(b1.opts().WithName("B")); // Give nodes 'c' and 'd' names that collide after lowercasing. 
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); - Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr( - "_encapsulate", "F1")); + Node* d = Binary(b, c, + b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); Binary(a, d, b1.opts().WithName("E")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } @@ -614,8 +614,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { Node* c = Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr( "_encapsulate", "F1")); - Node* d = - Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr( + Node* d = Binary(b, c, + b1.opts().WithName("D").WithControlInput(control).WithAttr( "_encapsulate", "F2")); Binary(a, d, b1.opts().WithName("E")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); @@ -707,7 +707,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_cluster", "_outside", graph_before_encapsulation, - /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, + /*rewrite_subgraph_fn=*/{}, /*reuse_existing_functions=*/false, &graph, &library)); std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; @@ -721,47 +721,6 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } -TEST(EncapsulateSubgraphsTest, ParallelChecking) { - Scope root = Scope::NewRootScope().ExitOnError().WithDevice( - "/job:localhost/replica:0/task:0/cpu:0"); - auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); - auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); - auto add1 = ops::Add(root.WithOpName("add1"), x1, x2); - add1.node()->AddAttr("_cluster", "cluster1"); - auto add2 = ops::Add(root.WithOpName("add2"), add1, x2); - add2.node()->AddAttr("_cluster", "cluster1"); - auto out = ops::Mul(root.WithOpName("mul"), x1, add2); - - Graph graph_before_encapsulation(OpRegistry::Global()); - 
TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation)); - - FunctionLibraryDefinition library(OpRegistry::Global(), {}); - std::unique_ptr graph; - TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", "_outside", graph_before_encapsulation, - /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true, - /*reuse_existing_functions=*/false, &graph, &library)); - - std::vector expected_nodes = { - "add1", "add2", "cluster1", "cluster1_parallel_check/_0", - "mul", "x1", "x2"}; - EXPECT_EQ(expected_nodes, GraphNodes(*graph)); - - std::vector> expected_edges = { - {"add1:0", "add2:0"}, - {"add2:0", "cluster1_parallel_check/_0:0"}, - {"cluster1:0", "cluster1_parallel_check/_0:1"}, - {"cluster1_parallel_check/_0:0", "mul:1"}, - {"x1:0", "add1:0"}, - {"x1:0", "cluster1:0"}, - {"x1:0", "mul:0"}, - {"x2:0", "add1:1"}, - {"x2:0", "add2:1"}, - {"x2:0", "cluster1:1"}, - }; - EXPECT_EQ(expected_edges, GraphEdges(*graph)); -} - const Node* FindNodeByName(const Graph& graph, const string& name) { for (const Node* node : graph.nodes()) { if (node->name() == name) return node; @@ -798,7 +757,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_encapsulate", "_outside", graph_before, /*rewrite_subgraph_fn=*/ - [&guaranteed_consts](std::unique_ptr* graph_ptr, + [&guaranteed_consts](const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, std::vector* input_permutation, std::vector* output_permutation, NodeDef* call_def) { @@ -814,7 +774,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { } return Status::OK(); }, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_after, &library)); EXPECT_EQ(2, guaranteed_consts); } @@ -843,7 +802,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_encapsulate", "_outside", graph_before, /*rewrite_subgraph_fn=*/ - [&guaranteed_consts](std::unique_ptr* graph_ptr, + 
[&guaranteed_consts](const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, std::vector* input_permutation, std::vector* output_permutation, NodeDef* call_def) { @@ -859,7 +819,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { } return Status::OK(); }, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_after, &library)); // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const // and another non-const, so overall non-const. diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 5d211f4d733d8d807426e62dd116092799184f35..5b6692f523658749f7ef48f9d7d89e97d4ce8b09 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -16,18 +16,6 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -cc_library( - name = "encapsulate_subgraphs_pass_flags", - srcs = ["encapsulate_subgraphs_pass_flags.cc"], - hdrs = ["encapsulate_subgraphs_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "mark_for_compilation_pass_flags", srcs = ["mark_for_compilation_pass_flags.cc"], diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc deleted file mode 100644 index 856475f12c8a411cd80c1c1859323304ca4029e0..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static EncapsulateSubgraphsPassFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new EncapsulateSubgraphsPassFlags; - flags->tf_xla_parallel_checking = false; - flag_list = new std::vector({ - Flag("tf_xla_parallel_checking", &flags->tf_xla_parallel_checking, - "Debug tool. Runs both JIT-compiled and interpreted graphs in " - "parallel and verifies they produce the same outputs."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -void AppendEncapsulateSubgraphsPassFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the EncapsulateSubgraphsPassFlags struct; -// repeated calls return the same pointer. 
-// This should be called only after Flags::Parse() has returned. -EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h deleted file mode 100644 index d371bd269dbdfbf737d81490fb877fcf88661a8f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ - -// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// encapsulate_subgraphs_pass module. 
-void AppendEncapsulateSubgraphsPassFlags( - std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -typedef struct { - bool tf_xla_parallel_checking; // Debug tool. Runs both JIT-compiled and - // interpreted graphs in parallel and verifies - // they produce the same outputs. -} EncapsulateSubgraphsPassFlags; - -// Return a pointer to the EncapsulateSubgraphsPassFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 74468266b9e983431732eafc801bc2d2ea682be9..8c3882116dd4f048ea3e32c037bf4139c67a3eb9 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -44,12 +44,6 @@ namespace tensorflow { namespace { -// Returns true if, when executed in TensorFlow, `node` is guaranteed to forward -// a ref tensor input to its output. -static bool AlwaysForwardsRefInput(const Node& node) { - return node.IsIdentity(); -} - bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient // is really a kind of function call and will be handled by @@ -68,20 +62,8 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // XLA does not offer guaranteed aliasing between the input and output of the // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave // such nodes out of XLA clusters. 
- if (AlwaysForwardsRefInput(node)) { - for (const Edge* incoming_edge : node.in_edges()) { - if (incoming_edge->IsControlEdge()) { - continue; - } - - Node* incoming_node = incoming_edge->src(); - if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) { - VLOG(2) << "Not clustering " << node.def().ShortDebugString() - << " because of ref input " << incoming_node->name() << " " - << incoming_node->type_string(); - return false; - } - } + if (HasForwardedRefInput(node)) { + return false; } return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 70bd10336b824b4aaef6520f0b094f52e5a0d626..05b7821b8865d0f210ca9af92370e177d6043e80 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/util/device_name_utils.h" @@ -66,6 +67,9 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, } return description; } + +bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); } + } // namespace Status DeviceToDeviceType(const string& device, DeviceType* device_type) { @@ -77,6 +81,24 @@ Status DeviceToDeviceType(const string& device, DeviceType* device_type) { return Status::OK(); } +bool HasForwardedRefInput(const Node& node) { + if (AlwaysForwardsRefInput(node)) { + for (const Edge* incoming_edge : node.in_edges()) { + if (incoming_edge->IsControlEdge()) { + continue; + } + + Node* incoming_node = incoming_edge->src(); + if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) { + VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input " + << incoming_node->name() << " " << 
incoming_node->type_string(); + return true; + } + } + } + return false; +} + Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { for (int i = 0; i < graph->num_node_ids(); ++i) { // We rely on the node IDs in the cycle detection graph being consecutive diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 5b673bdc27fccb4228b9e02cbf80d17aa35b5fe5..bcce082aaf6044ff0654efa4d78c0f493a350d00 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -36,6 +36,9 @@ using OrderedNodeSet = std::set; // Returns the DeviceType corresponding to 'device'. Status DeviceToDeviceType(const string& device, DeviceType* device_type); +// Returns true if `node` has a ref tensor input that it forwards to its output. +bool HasForwardedRefInput(const Node& node); + // Creates a graph representation to enable cycle detection when clustering. // This representation handles loops in graph by disconnecting each loop from // the enclosing graph. diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 0c49286acd3abaf8ea1f12a90d86a1d1ff38b234..11e45d2823da2b623bd3cd45f7147686b05fdb2f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" +#include "tensorflow/core/kernels/shape_ops.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { @@ -87,6 +88,46 @@ class XlaAssignVariableOp : public AsyncOpKernel { REGISTER_KERNEL_BUILDER( \ Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ ReadVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + RankOp); \ REGISTER_KERNEL_BUILDER( \ Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \ XlaAssignVariableOp); \ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 96016521ea902274e3ec1dcc35d3d070063eb1ae..74257b09a808a39454eace3b1a9bf57a2e071360 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ 
b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -178,6 +178,13 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, continue; } + // XLA does not offer guaranteed aliasing between the input and output of + // the XLA cluster so it can't implement the forward-tensor-ref semantic. + // Leave such nodes out of XLA clusters. + if (HasForwardedRefInput(*node)) { + continue; + } + compilation_candidates.insert(node); } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index e6c92f9720e1285617280f60d1c5fea443c5ebef..98fab319d6f4fbf3159b6e8815baea262b882d2a 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -51,6 +51,15 @@ py_library( ], ) +py_test( + name = "xla_test_test", + size = "small", + srcs = ["xla_test_test.py"], + deps = [ + ":xla_test", + ], +) + tf_xla_py_test( name = "adagrad_test", size = "small", diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 4dff5f0f405fb1d936ab2e6bcd82e05e926172c7..a4154ad1e846f8241a2ab6598da36ccb6b3b653e 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -31,11 +31,13 @@ from tensorflow.python.framework import ops from tensorflow.python.layers import convolutional from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import googletest +from tensorflow.python.training import adam class EagerTest(XLATestCase): @@ -160,6 +162,114 @@ class EagerTest(XLATestCase): for _ in range(100): values.append(var.value()) + # The shape, shape_n, size, and rank are tested here because their + # execution kernels (as opposed to compilation only tf2xla kernels) + # are distinct
from tf2xla kernels. + + def testShape(self): + def const(value): + return array_ops.shape( + constant_op.constant(value)).numpy() + + def ones(value): + return array_ops.shape( + array_ops.ones(value)).numpy() + + with self.test_scope(): + # Shapes of directly constructed tensors + self.assertAllEqual([], const(3)) + self.assertAllEqual([3], const([1.0, 2.0, 3.0])) + self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]])) + self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]])) + + # Shapes of tensors created by op running on device + # We make this distinction because directly constructed tensors + # are treated differently in a few places that can influence shape: + # - they always have on_host_tensor + # - they and their shapes can be cached + # - they end up on device via a copy, instead of as program output + self.assertAllEqual([], ones([])) + self.assertAllEqual([3], ones([3])) + self.assertAllEqual([2, 2], ones([2, 2])) + self.assertAllEqual([2, 1, 2], ones([2, 1, 2])) + + def testShapeN(self): + with self.test_scope(): + # Shapes of directly constructed tensors + shapes = array_ops.shape_n([ + constant_op.constant(1.0), + constant_op.constant([1.0, 2.0, 3.0]), + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + # Shapes of tensors created by op running on device + shapes = array_ops.shape_n([ + array_ops.ones([]), + array_ops.ones([3]), + array_ops.ones([2, 2])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + def testSize(self): + with self.test_scope(): + self.assertEqual( + 1, array_ops.size(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 4, array_ops.size( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + + def testRank(self): + with self.test_scope(): + self.assertEqual( + 0, 
array_ops.rank(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 2, array_ops.rank( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + + def testAdam(self): + with self.test_scope(): + optimizer = adam.AdamOptimizer(0.1) + x = resource_variable_ops.ResourceVariable(10.0) + with backprop.GradientTape() as tape: + y = x * x + dy_dx = tape.gradient(y, x) + optimizer.apply_gradients([(dy_dx, x)]) + self.assertAlmostEqual(9.9, x.numpy(), places=3) + + def testAdamSparse(self): + with ops.device('/cpu:0'): + # Create 2-D embedding for 3 objects on CPU because sparse/sliced updates + # are not implemented on TPU. + embedding_matrix = resource_variable_ops.ResourceVariable( + array_ops.ones([3, 2])) + + with self.test_scope(): + with backprop.GradientTape() as tape: + embedding = embedding_ops.embedding_lookup(embedding_matrix, [1]) + y = math_ops.reduce_sum(embedding) + dy_dx = tape.gradient(y, embedding_matrix) + self.assertIsInstance(dy_dx, ops.IndexedSlices) + optimizer = adam.AdamOptimizer(0.1) + # The gradient application operations will run on CPU because optimizer + # updates are always collocated with the variable. + optimizer.apply_gradients([(dy_dx, embedding_matrix)]) + + # This assign_add will run on CPU because when an input to an + # operation is a resource, this operation is placed on the resource's + # device by the eager runtime. 
+ embedding_matrix.assign_add(array_ops.ones([3, 2])) + + self.assertAllClose([[2.0, 2.0], + [1.9, 1.9], + [2.0, 2.0]], embedding_matrix.numpy()) + class EagerFunctionTest(XLATestCase): diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 689a4a1f4e02f5dd48f64dc94afd0fcb50df8b5b..e610b63e301c75f532db1b58cd26533effea174d 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -201,6 +201,16 @@ class UnaryOpsTest(XLATestCase): expected=np.array([1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype)) + # Disable float16 testing for now + if dtype != np.float16: + x = np.arange(-10, 10, 1).astype(dtype) + with self.test_session() as session: + erf_x = session.run(math_ops.erf(x)) + erfc_x = session.run(math_ops.erfc(x)) + + self._assertOpOutputMatchesExpected(math_ops.erf, x, expected=erf_x) + self._assertOpOutputMatchesExpected(math_ops.erfc, x, expected=erfc_x) + self._assertOpOutputMatchesExpected( math_ops.exp, np.array([[-1, 1]], dtype=dtype), diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index e924fe1e61454aefda622a5a46a0e483d26db5c1..88827cb53bee7bb809d0163d6badcef17e59aa78 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -49,6 +49,32 @@ flags.DEFINE_string('tf_xla_flags', None, 'Value to set the TF_XLA_FLAGS environment variable to') +def parse_disabled_manifest(manifest_content): + comments_re = re.compile('#.*$') + disabled_tests = [] + disabled_method_types = [] + for l in manifest_content.splitlines(): + stripped = comments_re.sub('', l).strip() + if not stripped: + continue + entry = stripped.split(' ') + if len(entry) == 1: + disabled_tests.append(entry[0]) + elif len(entry) == 2: + disabled_method_types.append((entry[0], entry[1].strip().split(','))) + else: + raise ValueError('Bad entry in manifest file.') + + disabled_regex = 
'|'.join(disabled_tests) + method_types_filter = dict() + for method, types in disabled_method_types: + method_types_filter[method] = set([ + dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype + for name in types + ]) + return disabled_regex, method_types_filter + + class XLATestCase(test.TestCase): """XLA test cases are parameterized test cases.""" @@ -85,38 +111,21 @@ class XLATestCase(test.TestCase): # Parse the manifest file, if any, into a regex identifying tests to # disable - self.disabled_regex = None - self._method_types_filter = dict() # TODO(xpan): Make it text proto if it doesn't scale. # Each line of the manifest file specifies an entry. The entry can be # 1) TestNameRegex // E.g. CumprodTest.* Or # 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16 # The 1) disables the entire test. While 2) only filter some numeric types # so that they are not used in those tests. + self.disabled_regex = None + self._method_types_filter = {} if FLAGS.disabled_manifest is not None: - comments_re = re.compile('#.*$') - manifest_file = open(FLAGS.disabled_manifest, 'r') - disabled_tests = [] - disabled_method_types = [] - for l in manifest_file.read().splitlines(): - if not l: - continue - entry = comments_re.sub('', l).strip().split(' ') - if len(entry) == 1: - disabled_tests.append(entry[0]) - elif len(entry) == 2: - disabled_method_types.append( - (entry[0], entry[1].strip().split(','))) - else: - raise ValueError('Bad entry in manifest file.') - - self.disabled_regex = re.compile('|'.join(disabled_tests)) - for method, types in disabled_method_types: - self._method_types_filter[method] = set([ - dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype - for name in types]) - manifest_file.close() + with open(FLAGS.disabled_manifest, 'r') as manifest_file: + disabled_regex, self._method_types_filter = ( + parse_disabled_manifest(manifest_file.read())) + if disabled_regex: + self.disabled_regex = re.compile(disabled_regex) if 
FLAGS.tf_xla_flags is not None: os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags diff --git a/tensorflow/compiler/tests/xla_test_test.py b/tensorflow/compiler/tests/xla_test_test.py new file mode 100644 index 0000000000000000000000000000000000000000..24664451579445edaadb335c30d253ee55f003da --- /dev/null +++ b/tensorflow/compiler/tests/xla_test_test.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the XLATestCase test fixture base class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.platform import test + + +class XlaTestCaseTestCase(test.TestCase): + + def testManifestEmptyLineDoesNotCatchAll(self): + manifest = """ +testCaseOne +""" + disabled_regex, _ = xla_test.parse_disabled_manifest(manifest) + self.assertEqual(disabled_regex, "testCaseOne") + + def testManifestWholeLineCommentDoesNotCatchAll(self): + manifest = """# I am a comment +testCaseOne +testCaseTwo +""" + disabled_regex, _ = xla_test.parse_disabled_manifest(manifest) + self.assertEqual(disabled_regex, "testCaseOne|testCaseTwo") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 
8b9b026643cf35216a2082dfcce9270c017bd14f..d48c6eea754f75a8879d3938f233a6a591d26d0d 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -48,11 +48,11 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; - std::vector inputs(input_types_.size()); std::vector arguments(input_types_.size()); for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; DataType type = ctx->input_type(i + 1); + if (type == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); @@ -60,7 +60,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.initialized = resource->initialized(); arg.kind = XlaCompiler::Argument::kResource; arg.resource_kind = resource->kind(); - OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); arg.type = resource->type(); arg.shape = resource->shape(); @@ -79,7 +78,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; arg.shape = ctx->InputShape(i + 1); - inputs[i] = ctx->Input(i + 1); VLOG(2) << "Arg type: " << DataTypeString(arg.type) << " shape: " << arg.shape.DebugString(); } @@ -100,6 +98,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, arguments, &else_result)); + bool has_tensor_array_gradients = false; for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { XlaResource* resource; @@ -121,9 +120,21 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } + if (!resource->tensor_array_gradients().empty()) + has_tensor_array_gradients = true; } } + // Recompile the functions to update the argument shapes for 
tensor arrays. + if (has_tensor_array_gradients) { + then_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, + arguments, &then_result)); + else_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, + arguments, &else_result)); + } + // Check that both branches have identical input shapes. OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); @@ -175,6 +186,19 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { "Mismatch in resource of then and else branch for resource ", i)); } + int num_inputs = then_result.input_mapping.size(); + std::vector inputs(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + int input_num = then_result.input_mapping[i] + 1; + if (ctx->input_type(input_num) == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); + } else { + inputs[i] = ctx->Input(i + 1); + } + } + xla::XlaOp outputs = b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, b->Tuple(inputs), *else_result.computation); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index ebac5c4396f90f9cee5d900d3c34499677c1a02f..105be38fe26b6667e8b4ce6da92a3969cdc0c187 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -76,32 +76,14 @@ class RandomShuffleOp : public XlaOpKernel { ctx->SetOutput(0, input); } else { // Generate the random swaps for the indices. 
- auto zero = builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(xla::S32)), - gtl::ArraySlice({n})); - auto n_maxval = builder->Broadcast(builder->ConstantR0(n), - gtl::ArraySlice({n})); auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); - auto swaps = builder->RngUniform(zero, n_maxval, swaps_shape); + auto swaps = + builder->RngUniform(builder->ConstantR0(0), + builder->ConstantR0(n), swaps_shape); // Generate range(n) as the initial value for the indices to be swapped. - auto index_init_body_fn = [&](xla::XlaOp i, - gtl::ArraySlice loop_vars, - xla::XlaBuilder* builder) - -> xla::StatusOr> { - auto indices = loop_vars[0]; - i = builder->Reshape(i, {}, {1}); - // indices[i] = i - indices = builder->DynamicUpdateSlice(indices, i, i); - return std::vector{indices}; - }; - // for i in range(n): - xla::XlaOp index_zeros = Zeros(builder, swaps_shape); - auto index_init_loop_result = - XlaForEachIndex(n, xla::S32, index_init_body_fn, {index_zeros}, - "index_init_loop", builder) - .ValueOrDie(); - auto indices = index_init_loop_result[0]; + xla::XlaOp indices; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices)); // Swap the indices at i and swaps[i]. 
auto swap_body_fn = [&](xla::XlaOp i, @@ -110,7 +92,7 @@ class RandomShuffleOp : public XlaOpKernel { -> xla::StatusOr> { auto swaps = loop_vars[0]; auto indices = loop_vars[1]; - i = builder->Reshape(i, {}, {1}); + i = builder->Reshape(i, {1}); // temp = indices[i] auto temp = builder->DynamicSlice(indices, i, {1}); // swap_index = swaps[i] diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 05354bca5bb089703fdcceb6f44648bbb98d004b..d59720bef742c7441ee01a954247013559bb909c 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -43,7 +43,7 @@ class ShapeOp : public XlaOpKernel { DataType out_dtype_; }; -REGISTER_XLA_OP(Name("Shape"), ShapeOp); +REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -65,7 +65,7 @@ class ShapeNOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -81,7 +81,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Rank"), RankOp); +REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp); class SizeOp : public XlaOpKernel { public: @@ -100,7 +100,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Size"), SizeOp); +REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: @@ -189,10 +189,9 @@ class SqueezeOp : public XlaOpKernel { if (!wrapped_squeeze_dims.empty()) { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, - errors::InvalidArgument("Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", - existing_dim)); + errors::InvalidArgument( + "Tried to explicitly squeeze dimension ", i, + " but dimension was not 1: ", existing_dim)); } else { // This dimension is not being 
squeezed. new_shape.push_back(existing_dim); diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 71a9fd051bfc8db09738a4bfe8ddde447895ecf0..2521445e86998cb027f94838650a049c9fd7e1a3 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -16,9 +16,11 @@ limitations under the License. // Native XLA implementations of simple unary Ops #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -185,5 +187,49 @@ XLAJIT_MAKE_UNARY(Imag, b->Imag(x)); #undef XLAJIT_MAKE_UNARY +// Erf/Erfc. For x in (-1, 1), the erf approximation is used; erfc polynomial +// is used outside of this range. 
+class ErfOp : public XlaOpKernel { + public: + explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::PrimitiveType primitive_type; + xla::XlaOp one = XlaHelpers::One(b, input_type(0)); + xla::XlaOp x = ctx->Input(0); + xla::XlaOp abs_x = b->Abs(x); + + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(input_type(0), &primitive_type)); + + auto y = b->Select(b->Gt(abs_x, one), + b->Sub(one, ComputeErfc(b, x, primitive_type)), + ComputeErf(b, x, primitive_type)); + ctx->SetOutput(0, y); + } +}; +REGISTER_XLA_OP(Name("Erf"), ErfOp); + +class ErfcOp : public XlaOpKernel { + public: + explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp one = XlaHelpers::One(b, input_type(0)); + xla::XlaOp x = ctx->Input(0); + xla::XlaOp abs_x = b->Abs(x); + + xla::PrimitiveType primitive_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(input_type(0), &primitive_type)); + + auto y = b->Select(b->Lt(abs_x, one), + b->Sub(one, ComputeErf(b, x, primitive_type)), + ComputeErfc(b, x, primitive_type)); + ctx->SetOutput(0, y); + } +}; +REGISTER_XLA_OP(Name("Erfc"), ErfcOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 526694d5a0c7124e1696f34b516f3b202462bc19..ee0bb91a6b747ffc9e28e19dd4869a5b2cc43501 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -71,8 +71,8 @@ xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, } // Check for zero lhs/rhs dim size. 
- if (xla::ShapeUtil::HasZeroElements(x_shape) || - xla::ShapeUtil::HasZeroElements(y_shape)) { + if (xla::ShapeUtil::IsZeroElementArray(x_shape) || + xla::ShapeUtil::IsZeroElementArray(y_shape)) { std::vector dimensions(batch_dimension_numbers.size()); for (int i = 0; i < batch_dimension_numbers.size(); ++i) { dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 3f1384bc864abd882ebba2b90acbe0b1e664687a..20925118bf598a6436c43bd727ce40e3abafc46c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -110,7 +110,6 @@ xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder, FloatLiteral(body_builder, a_shape.element_type(), 0.5)); // a[..., i+1:, i] - auto ip1 = body_builder->Add(i, body_builder->ConstantR0(1)); // select the whole i-th column, then mask out all rows above i+1 TF_ASSIGN_OR_RETURN( auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 43e1c1e9fecec1c71db1509757251cb5d903ca49..db56b128375ce8ff2faf12c5d7ea256bdfab0f63 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,6 +40,37 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + *literal = xla::BorrowingLiteral( + static_cast(DMAHelper::base(&host_tensor)), xla_shape); + return Status::OK(); +} + +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal) { + std::vector buf_ptrs; + buf_ptrs.reserve(host_tensors.size()); 
+ std::vector tensor_shapes(host_tensors.size()); + + for (int i = 0; i < host_tensors.size(); i++) { + // Validate runtime shapes and fail if it doesn't match the contract. + const Tensor* tensor = &host_tensors[i]; + buf_ptrs.emplace_back(static_cast(DMAHelper::base(tensor))); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(), + &tensor_shapes[i])); + } + + *literal = xla::BorrowingLiteral( + buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes)); + + return Status::OK(); +} + Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 220bec15538c36fa30abef9e729b64dbbb9f72b3..74685025c1780c5c0ba56205a98786582e9191e9 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -29,6 +30,17 @@ namespace tensorflow { // unsupported type. Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); +// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by +// 'host_tensor'. +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal); + +// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers +// owned by 'host_tensors'. +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal); + // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type . 
// Fails if the literal's primitive type != diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f1594193af09c7193f03b4685d3a7d4510d654dd..a1da176fe30ddd0d4460a51b60b2568ecc1af6aa 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -19,11 +19,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -210,8 +212,9 @@ Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, return errors::InvalidArgument("Invalid argument type ", DataTypeString(dtype)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); + *iota = builder->ConstantLiteral(linspace_literal); return Status::OK(); } @@ -245,8 +248,8 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, return errors::InvalidArgument("Invalid argument type ", DataTypeString(index_type)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); // Broadcast the linspace constant across the indices along the new axis, // and test equality at 
each position. diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 1b8e516770c3e217dd7c2f26ce426895b478c2e4..4525197146b7f29f405650bdb08e5946cbce8114 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -309,7 +309,6 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index dc69d2097ebe14ca0e14a39849d4fcae99024fdc..5c9abad4c3126be5e45e96c770c0679fe8606788 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -24,7 +24,8 @@ namespace xla { StatusOr>> CompileOnlyClient::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { + const AotCompilationOptions& options, + std::unique_ptr* metadata) { std::vector service_instances; service_instances.reserve(computations.size()); for (const AotXlaComputationInstance& instance : computations) { @@ -36,7 +37,8 @@ CompileOnlyClient::CompileAheadOfTime( service_instance.argument_layouts = instance.argument_layouts; service_instance.result_layout = instance.result_layout; } - return compiler_service_->CompileAheadOfTime(service_instances, options); + return compiler_service_->CompileAheadOfTime(service_instances, options, + metadata); } int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) { diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index f9a7c31270c7a11175f47a537639a97d0c9211af..332c96503637344d56e363e19db4880c37ca9684 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -46,13 +46,15 @@ class CompileOnlyClient : public Client { const Shape* result_layout; }; - // Compiles a list of xla 
computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. + // Compiles a list of xla computations for ahead-of-time execution. + // This is intended for use in static compilation. The |options| + // parameter describes the target for which the compiler should emit + // code. |metadata|, if provided, is populated during compilation. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options); + const AotCompilationOptions& options, + std::unique_ptr* metadata = nullptr); // Returns the size of a pointer in bytes for a given triple. static int64 PointerSizeForTriple(tensorflow::StringPiece triple); diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index a1d34796ccfd86f2025eff0ecb51338eb6a9b1da..639f85737f0173f47d494f366b220ab60e09629e 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -121,4 +121,88 @@ StatusOr Any(const XlaOp& predicates, XlaBuilder* builder) { return builder->Reduce(predicates, f, logical_or, all_dimensions); } +namespace { +xla::XlaOp FloatLiteral(xla::XlaBuilder* b, PrimitiveType data_type, + float value) { + return b->ConvertElementType(b->ConstantR0(value), data_type); +} + +// Polynomials for computing erf/erfc. Originally from cephes. +// Note we use float for compatibility across devices, at the cost of some +// precision for 64 bit computations. +// +// Coefficients are in descending order. 
+std::array kErfcPCoefficient = { + 2.46196981473530512524E-10, 5.64189564831068821977E-1, + 7.46321056442269912687E0, 4.86371970985681366614E1, + 1.96520832956077098242E2, 5.26445194995477358631E2, + 9.34528527171957607540E2, 1.02755188689515710272E3, + 5.57535335369399327526E2}; +std::array kErfcQCoefficient = { + 1.00000000000000000000E0, 1.32281951154744992508E1, + 8.67072140885989742329E1, 3.54937778887819891062E2, + 9.75708501743205489753E2, 1.82390916687909736289E3, + 2.24633760818710981792E3, 1.65666309194161350182E3, + 5.57535340817727675546E2}; +std::array kErfcRCoefficient = { + 5.64189583547755073984E-1, 1.27536670759978104416E0, + 5.01905042251180477414E0, 6.16021097993053585195E0, + 7.40974269950448939160E0, 2.97886665372100240670E0}; +std::array kErfcSCoefficient = { + 1.00000000000000000000E0, 2.26052863220117276590E0, + 9.39603524938001434673E0, 1.20489539808096656605E1, + 1.70814450747565897222E1, 9.60896809063285878198E0, + 3.36907645100081516050E0}; +std::array kErfTCoefficient = { + 9.60497373987051638749E0, 9.00260197203842689217E1, + 2.23200534594684319226E3, 7.00332514112805075473E3, + 5.55923013010394962768E4}; +std::array kErfUCoefficient = { + 1.00000000000000000000E0, 3.35617141647503099647E1, + 5.21357949780152679795E2, 4.59432382970980127987E3, + 2.26290000613890934246E4, 4.92673942608635921086E4}; +} // namespace + +// Evaluate the polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. +xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x, + tensorflow::gtl::ArraySlice coefficients, + PrimitiveType data_type) { + xla::XlaOp poly = FloatLiteral(b, data_type, 0.0); + for (float c : coefficients) { + poly = b->Add(b->Mul(poly, x), FloatLiteral(b, data_type, c)); + } + return poly; +} + +// Compute an approximation of the error function complement (1 - erf(x)). 
+xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type) { + xla::XlaOp zero = FloatLiteral(b, data_type, 0.0); + xla::XlaOp two = FloatLiteral(b, data_type, 2.0); + xla::XlaOp eight = FloatLiteral(b, data_type, 8.0); + + xla::XlaOp abs_x = b->Abs(x); + xla::XlaOp z = b->Exp(b->Mul(b->Neg(x), x)); + + xla::XlaOp pp = EvaluatePolynomial(b, abs_x, kErfcPCoefficient, data_type); + xla::XlaOp pq = EvaluatePolynomial(b, abs_x, kErfcQCoefficient, data_type); + xla::XlaOp pr = EvaluatePolynomial(b, abs_x, kErfcRCoefficient, data_type); + xla::XlaOp ps = EvaluatePolynomial(b, abs_x, kErfcSCoefficient, data_type); + + xla::XlaOp y = b->Select(b->Lt(abs_x, eight), b->Div(b->Mul(z, pp), pq), + b->Div(b->Mul(z, pr), ps)); + + return b->Select(b->Lt(x, zero), b->Sub(two, y), y); +} + +// Compute a polynomial approximation of the error function. +xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type) { + xla::XlaOp z = b->Mul(x, x); + xla::XlaOp pt = EvaluatePolynomial(b, z, kErfTCoefficient, data_type); + xla::XlaOp pu = EvaluatePolynomial(b, z, kErfUCoefficient, data_type); + return b->Div(b->Mul(x, pt), pu); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 64b6b7d63353165e45bf12d35126a7eeef9e56e4..f11cc003177c7eb68c32f9e618704a1ac7e63a73 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -55,6 +55,20 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder); // Note: if predicates is zero-sized, Any() vacuously returns false. StatusOr Any(const XlaOp& predicates, XlaBuilder* builder); +// Evaluate the polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. 
+xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x, + tensorflow::gtl::ArraySlice coefficients, + PrimitiveType data_type); + +// Compute an approximation of the error function complement (1 - erf(x)). +xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type); + +// Compute an approximation of the error function. +xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 5e17cc4dfb0b225712e94041970545ff19f03b98..d7ebcf8bebc1f656b4965c833e0d42ccceb1b99f 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -1611,7 +1611,9 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, }); } -XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { +XlaOp XlaBuilder::CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids) { return NoteErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); @@ -1619,7 +1621,7 @@ XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); TF_ASSIGN_OR_RETURN(auto computation, b->Build()); - return CrossReplicaSum(operand, computation, /*replica_group_ids=*/{}, + return CrossReplicaSum(operand, computation, replica_group_ids, /*channel_id=*/tensorflow::gtl::nullopt); }); } @@ -1629,9 +1631,8 @@ XlaOp XlaBuilder::CrossReplicaSum( tensorflow::gtl::ArraySlice replica_group_ids, const tensorflow::gtl::optional& channel_id) { return NoteErrorOrReturn([&]() -> StatusOr { - if 
(!replica_group_ids.empty() || channel_id.has_value()) { - return Unimplemented( - "replica_group_ids and channel_id and is not supported in AllReduce"); + if (channel_id.has_value()) { + return Unimplemented("channel_id is not supported in AllReduce"); } HloInstructionProto instr; @@ -1639,6 +1640,9 @@ XlaOp XlaBuilder::CrossReplicaSum( TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + for (int64 replica_group_id : replica_group_ids) { + instr.add_replica_group_ids(replica_group_id); + } AddCalledComputation(computation, &instr); diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 532cae014848e17b24ee720a3c3dc5f99c89dfe5..0329e42ed1aef8edd1537e888ddcd78f08584407 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -528,9 +528,12 @@ class XlaBuilder { tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding); - // Returns the sum of the operand value across all replicas. All replicas - // supply one input to the sum and all replicas receive the resulting sum. - XlaOp CrossReplicaSum(const XlaOp& operand); + // Returns the sum of the operand value within each subgroup of replicas. All + // replicas supply one input to the sum and all replicas receive the resulting + // sum for each subgroup. + XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. 
Here // AllReduce means doing a reduction on the input operand cross cores and then diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/BUILD b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a26b20c861846501c911253d89619591c37322b3 --- /dev/null +++ b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD @@ -0,0 +1,18 @@ +# Description: +# Python API for shardings in XLA. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_library( + name = "xla_sharding", + srcs = ["xla_sharding.py"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/compiler/xla/python_api:types", + "//tensorflow/compiler/xla/python_api:xla_shape", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..abd10b164eaef8e75ed304483861baf250c5b954 --- /dev/null +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -0,0 +1,204 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ====================================== +"""Experimental support for defining XLA shardings.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import xla_shape +from tensorflow.core.framework import attr_value_pb2 + + +class Sharding(object): + """A class to support adding sharding attributes to Ops. + + Use the factory constructors and then call apply_to_tensor: + Sharding.replicate().apply_to_tensor(tensor) + """ + + def __init__(self, proto=None): + """Do not use this constructor; use the factory functions below.""" + self._proto = proto + + @classmethod + def replicate(cls): + """Returns a replicated sharding attribute. + + This causes an op to be computed in its entirety independently on all + cores in the XLA device. + """ + return Sharding( + proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)) + + @classmethod + def assign_device(cls, core): + """Returns an AssignDevice sharding attribute. + + This causes an op to be computed in its entirety only on one core in + the XLA device. + Args: + core: The core to assign this Op to. + """ + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.MAXIMAL, + tile_assignment_dimensions=[1], + tile_assignment_devices=[core])) + + @classmethod + def tile(cls, tile_shape, tile_assignment): + """Returns a Tiled sharding attribute. + + This causes an op to be partially computed on multiple cores in the + XLA device. + + Args: + tile_shape: A xla_shape.Shape describing the tile shape that each core + will compute. + The tile shape does not need to be divisible by the tile assignment. + tile_assignment: An np.ndarray describing the topology of the tiling and + which device will compute which part of the topology. 
+ + Raises: + TypeError: tile_assignment was not of np.array type or tile_shape was + not of xla_shape.Shape type. + + TODO(jmolloy): This concept is nefarious and is not + something we really want to expose to users (especially as the + contract for tile_assignment is very strict). + """ + if not isinstance(tile_assignment, np.ndarray): + raise TypeError('Tile assignment must be of type np.ndarray') + if not isinstance(tile_shape, xla_shape.Shape): + raise TypeError('Tile shape must be of type xla_shape.Shape') + dims = list(tile_assignment.shape) + flattened_devices = tile_assignment.reshape(-1, order='C') + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.OTHER, + tile_shape=tile_shape.message, + tile_assignment_dimensions=dims, + tile_assignment_devices=list(flattened_devices))) + + @classmethod + def split(cls, tensor, split_dimension, num_devices): + """Returns a Sharding that splits a tensor across a dimension. + + This creates a Tiled attribute, similar to tile(), but easier to use for the + common case of tiling a tensor N ways in one dimension. + + Args: + tensor: A tf.Tensor to split. + split_dimension: The dimension number to split. + num_devices: The number of cores to split `tensor` over. + + Raises: + ValueError: The tensor to split was smaller in the split dimension than + the number of devices to split over. 
+ """ + tensor.shape.assert_is_fully_defined() + shape = tensor.shape.as_list() + if shape[split_dimension] < num_devices: + raise ValueError('Split dimension was smaller than the required number ' + 'of splits: shape=%r, dimension=%r, num_devices=%r', + shape, split_dimension, num_devices) + + tile_shape = shape + tile_shape[split_dimension] = int( + math.ceil(tile_shape[split_dimension] / num_devices)) + tile_shape_proto = xla_data_pb2.Shape( + element_type=xla_data_pb2.F32, dimensions=tile_shape) + + tile_assignment_dims = [1] * len(shape) + tile_assignment_dims[split_dimension] = num_devices + + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.OTHER, + tile_shape=tile_shape_proto, + tile_assignment_dimensions=tile_assignment_dims, + tile_assignment_devices=range(num_devices))) + + def apply_to_tensor(self, tensor): + """Applies this Sharding attribute to `tensor`.""" + if len(tensor.op.outputs) > 1: + proto = self._get_or_create_tuple_proto(tensor.op) + # We can't mutate an element of old_proto.tuple_shardings, so create + # a new proto. + tuple_shardings = list(proto.tuple_shardings) + tuple_shardings[tensor.value_index] = self._proto + proto = xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings) + else: + proto = self._proto + + attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString()) + # TODO(jmolloy): This need to be seriously revisited before declaring this + # API available for public use. 
+ # pylint: disable=protected-access + tensor.op._set_attr('_XlaSharding', attr_value) + + @property + def proto(self): + """Return the sharding protobuf of type xla_data_pb2.OpSharding.""" + return self._proto + + def _get_or_create_tuple_proto(self, op): + try: + attr = op.get_attr('_XlaSharding') + proto = xla_data_pb2.OpSharding() + proto.ParseFromString(attr) + return proto + except ValueError: + return self._create_tuple_proto(op) + + def _create_tuple_proto(self, op): + shardings = [ + xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED) + for _ in op.outputs + ] + return xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings) + + +# Helpers for the above factory functions that allow easy application of +# shardings, for example: +# tensor = xla_sharding.replicate(tensor) + + +def replicate(tensor): + Sharding.replicate().apply_to_tensor(tensor) + return tensor + + +def assign_device(tensor, device): + Sharding.assign_device(device).apply_to_tensor(tensor) + return tensor + + +def tile(tensor, tile_shape, tile_assignment): + Sharding.tile(tile_shape, tile_assignment).apply_to_tensor(tensor) + return tensor + + +def split(tensor, split_dimension, num_devices): + Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor) + return tensor diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index e8f29b83291a7cb238dc25b9f4bb743fe426a162..3f059cac30b5d36ab1d097bf200547533822e3d0 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -190,9 +190,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (!ShapeUtil::IsArray(shape)) { - return InvalidArgument( - "shape of primitive type %s should not have a layout", - PrimitiveType_Name(shape.element_type()).c_str()); + if (layout.minor_to_major_size() != 0 || + layout.padded_dimensions_size() != 0) { + return InvalidArgument( + "shape of primitive type %s should not 
have a non-trivial layout", + PrimitiveType_Name(shape.element_type()).c_str()); + } + return Status::OK(); } if (layout.format() == INVALID_FORMAT) { diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index bf9679cafec72c2e9dc5796e9058c6703239c508..2125ab7c61ab5e30fe51e16994e0da4883d509c4 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -606,8 +606,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, } // namespace Status EqualShapes(const Shape& expected, const Shape& actual) { - if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { - return InvalidArgument("tupleness-mismatch! want: %s got %s", + if (expected.element_type() != actual.element_type()) { + return InvalidArgument("element type mismatch, want: %s got %s", ShapeUtil::HumanString(expected).c_str(), ShapeUtil::HumanString(actual).c_str()); } @@ -626,7 +626,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return AppendStatus(result, StrCat("mismatch in tuple index", i)); } } - } else { + } else if (ShapeUtil::IsArray(expected)) { if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { return InvalidArgument("want rank of %s got rank of %s", ShapeUtil::HumanString(expected).c_str(), @@ -652,6 +652,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { } } } + // Non-array, non-tuple shapes are trivially equivalent. return Status::OK(); } @@ -705,6 +706,9 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { } break; } + case TOKEN: + // Tokens have no on-device representation and are trivially equal. 
+ return Status::OK(); default: LOG(FATAL) << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 61afc311a702930a18be4842908f9a26b98d9a32..19e6d288c00a7a541e01390af4946c0caa06615e 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -148,8 +148,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { piece->emplace_back(std::move(child_piece)); } - } else { - CHECK(ShapeUtil::IsArray(shape)); + } else if (ShapeUtil::IsArray(shape)) { if (allocate_arrays) { if (LayoutUtil::IsSparseArray(shape)) { // For sparse arrays, the buffer must be of the size of the maximum @@ -165,6 +164,10 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { piece->set_buffer(new char[piece->size_bytes()]); } } + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(piece->size_bytes(), 0); } } @@ -264,8 +267,8 @@ Status Literal::CopySliceFromInternal( StridedCopy(data(), linear_index(shape(), dest_base), 0, src_literal.data(), linear_index(src_literal.shape(), src_base), 0, 1); - } else if (!ShapeUtil::HasZeroElements(shape()) && - !ShapeUtil::HasZeroElements(src_literal.shape())) { + } else if (!ShapeUtil::IsZeroElementArray(shape()) && + !ShapeUtil::IsZeroElementArray(src_literal.shape())) { // Perform copy if neither src nor dest has dimensions with zero element, // otherwise it's a no-op. 
TF_RET_CHECK(src_base.size() == dest_base.size()); @@ -327,6 +330,10 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } +/* static */ std::unique_ptr Literal::CreateToken() { + return MakeUnique(ShapeUtil::MakeTokenShape()); +} + std::vector Literal::DecomposeTuple() { CHECK(ShapeUtil::IsTuple(shape())); std::vector elements; @@ -379,7 +386,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, tensorflow::gtl::ArraySlice src, const Shape& dest_shape, const Shape& src_shape) { CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); - if (ShapeUtil::HasZeroElements(dest_shape)) { + if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; } std::vector index(ShapeUtil::Rank(dest_shape)); @@ -1177,7 +1184,7 @@ size_t LiteralBase::Hash() const { ShapeUtil::ForEachSubshape( shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsTuple(subshape)) { + if (!ShapeUtil::IsArray(subshape)) { return; } @@ -1368,6 +1375,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, return; } + if (ShapeUtil::IsToken(subshape)) { + pieces->push_back("token"); + return; + } + if (LayoutUtil::IsSparseArray(subshape)) { pieces->push_back(shape_to_string(subshape)); pieces->push_back("{"); @@ -1556,7 +1568,7 @@ string LiteralBase::ToString(bool print_layout) const { void LiteralBase::EachCellAsString( const std::function indices, const string& value)>& per_cell) const { - if (ShapeUtil::HasZeroElements(shape())) { + if (ShapeUtil::IsZeroElementArray(shape())) { return; } std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( @@ -1962,7 +1974,7 @@ bool LiteralBase::IsAllFirst() const { // Empty shapes are not all the first element since there is no first // element. 
- if (ShapeUtil::HasZeroElements(piece.subshape())) { + if (ShapeUtil::IsZeroElementArray(piece.subshape())) { return false; } auto piece_is_all = [&]() { @@ -2341,28 +2353,28 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal, : LiteralBase(), root_piece_(&literal.piece(view_root)) {} BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(shape) { - CHECK(ShapeUtil::IsArray(shape_)); + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsArray(*shape_)); CHECK_NE(src_buf_ptr, nullptr); - CHECK(LayoutUtil::HasLayout(shape_)); + CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = Piece(); root_piece_.set_buffer(const_cast(src_buf_ptr)); - root_piece_.set_subshape(&shape_); + root_piece_.set_subshape(shape_.get()); } BorrowingLiteral::BorrowingLiteral( tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(shape) { - CHECK(ShapeUtil::IsTuple(shape_)); - CHECK(!ShapeUtil::IsNestedTuple(shape_)); - CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_)); + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); root_piece_ = Piece(); - root_piece_.set_subshape(&shape_); - BuildPieceSubtree(shape_, &root_piece_); + root_piece_.set_subshape(shape_.get()); + BuildPieceSubtree(*shape_, &root_piece_); for (int i = 0; i < src_buf_ptrs.size(); ++i) { - const auto& src_shape = shape_.tuple_shapes(i); + const auto& src_shape = shape_->tuple_shapes(i); CHECK(ShapeUtil::IsArray(src_shape)); root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 1e26eb7ad4098bab1e757347a23edd73390b48b5..37ca8ea9f1d158b6bce8d5688288351f55c3b3c8 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ 
b/tensorflow/compiler/xla/literal_util.h @@ -917,6 +917,9 @@ class Literal : public LiteralBase { return MakeTupleOwned(std::move(v)); } + // Create a constant token literal. Token types have no value. + static std::unique_ptr CreateToken(); + // Returns a vector containing the tuple elements of this Literal as separate // Literals. This Literal must be tuple-shaped and can be a nested tuple. The // elements are moved into the new Literals; no data is copied. Upon return @@ -1099,8 +1102,10 @@ class BorrowingLiteral : public LiteralBase { const Piece& root_piece() const override { return root_piece_; }; Piece root_piece_; - // Shape of this literal. - const Shape shape_; + // Shape of this literal. Stored as unique_ptr so such that the (default) + // move construction of this class would be trivially correct: the pointer to + // Shape root_piece_ stores will still point to the correct address. + std::unique_ptr shape_; }; template @@ -1454,7 +1459,7 @@ void LiteralBase::EachCell( std::function indices, NativeT value)> per_cell) const { - if (ShapeUtil::HasZeroElements(shape())) { + if (ShapeUtil::IsZeroElementArray(shape())) { return; } std::vector indices(ShapeUtil::Rank(shape()), 0); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index f127cee0fdc126429ed423aace3b3b7764a05b2e..493d807591dd3c425293e4ee796bca3036a3088c 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -334,6 +334,22 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { EXPECT_EQ(nil, nil); } +TEST_F(LiteralUtilTest, TokenEquality) { + auto token0 = Literal::CreateToken(); + auto token1 = Literal::CreateToken(); + auto scalar = Literal::CreateR0(1.0); + + EXPECT_EQ(*token0, *token1); + EXPECT_NE(*token0, *scalar); + + EXPECT_EQ(*Literal::MakeTuple({token0.get()}), + *Literal::MakeTuple({token0.get()})); + EXPECT_EQ(*Literal::MakeTuple({token0.get(), scalar.get()}), + 
*Literal::MakeTuple({token1.get(), scalar.get()})); + EXPECT_NE(*Literal::MakeTuple({token0.get(), scalar.get()}), + *Literal::MakeTuple({scalar.get(), token1.get()})); +} + TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. auto colmajor = @@ -1431,7 +1447,7 @@ TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } -TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { +TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { std::vector int64_values = {1, 2, 3}; const Shape literal_shape = ShapeUtil::MakeShape(S64, {3}); @@ -1443,7 +1459,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { EXPECT_EQ(literal.Get({2}), 3); } -TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) { +TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { std::vector one_two_three = {1, 2, 3}; const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3}); diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 143c9a2366be5786b7ef2148580caeb97d67d2d8..b16147e3be71771269d8b7a18528bef3a8c72d99 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -85,5 +85,10 @@ PrimitiveType ComplexComponentType(PrimitiveType complex_type) { } } +bool IsArrayType(PrimitiveType primitive_type) { + return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && + primitive_type != OPAQUE && primitive_type != TOKEN; +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index b26a10ade63a5dad3bf8f9f3a2a33c3c5e67bdb2..889e9a1ceca675689406d255d348c82c398563aa 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -133,6 +133,9 @@ bool 
IsUnsignedIntegralType(PrimitiveType type); bool IsIntegralType(PrimitiveType type); +// Returns true if values of the given primitive type are held in array shapes. +bool IsArrayType(PrimitiveType primitive_type); + // Returns the number of bits in the representation for a given type. int BitWidth(PrimitiveType type); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index f808990cadeab5fd2c4857920ee1daaac7262edd..445cee1aa7b462f7ae2b6b0771ff57f0c8f3db99 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" namespace xla { - namespace swig { // TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of @@ -97,6 +96,36 @@ const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { return &shaped_buffer_; } +ShapedBuffer LocalShapedBuffer::Release() { return shaped_buffer_.release(); } + +LocalShapedBufferTuple::LocalShapedBufferTuple( + std::vector elements) + : elements_(std::move(elements)) { + for (auto* element : elements_) { + DCHECK(element != nullptr); + } +} + +LocalShapedBufferTuple::~LocalShapedBufferTuple() { + for (LocalShapedBuffer* element : elements_) { + if (element != nullptr) { + delete element; + } + } +} + +StatusOr LocalShapedBufferTuple::Release(int i) { + LocalShapedBuffer* element = elements_[i]; + if (element == nullptr) { + return InvalidArgument("Attempted to release already-released element %d.", + i); + } + elements_[i] = nullptr; + return element; +} + +int LocalShapedBufferTuple::size() const { return elements_.size(); } + static StatusOr ToBuffer(LocalClient* client, int device_ordinal, const Literal& arg) { @@ -598,10 +627,12 @@ _FORWARD_BINOP(Or) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) +_FORWARD_UNOP(Expm1) 
_FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) +_FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) @@ -631,6 +662,54 @@ void DeleteLocalComputation(LocalComputation* computation) { delete computation; } -} // namespace swig +StatusOr DestructureLocalShapedBufferTuple( + LocalShapedBuffer* local_shaped_buffer) { + if (!ShapeUtil::IsTuple( + local_shaped_buffer->shaped_buffer()->on_device_shape())) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString( + local_shaped_buffer->shaped_buffer()->on_device_shape()) + .c_str()); + } + DeviceMemoryAllocator* allocator = + local_shaped_buffer->shaped_buffer()->memory_allocator(); + ShapedBuffer tuple_buffer = local_shaped_buffer->Release(); + + // Extract some metadata we use to construct scoped buffers. + const se::Platform* platform = tuple_buffer.platform(); + int device_ordinal = tuple_buffer.device_ordinal(); + + ShapeTree& shape_tree = tuple_buffer.buffers(); + const Shape& tuple_shape = tuple_buffer.on_device_shape(); + std::vector results; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + // Create a shaped buffer for this destructured tuple element. 
+ const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); + VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; + ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); + + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& index) { + ShapeIndex original(index); + original.push_front(i); + se::DeviceMemoryBase* device_memory = + shape_tree.mutable_element(original); + shaped_buffer.set_buffer(*device_memory, index); + *device_memory = se::DeviceMemoryBase(); + }); + + VLOG(3) << "Completed tuple element: " << i; + results.push_back(new LocalShapedBuffer( + ScopedShapedBuffer(std::move(shaped_buffer), allocator))); + } + // Deallocate the root buffer. + se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); + TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); + return new LocalShapedBufferTuple(std::move(results)); +} + +} // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 9ac13b65231c932f152c1e79eb8e576cc6331fbd..0da3964676e9c6729229686f38bb05c8b2427bff 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { - namespace swig { // Initializes the number of replicas that XLA will be initialized with (when @@ -69,10 +68,42 @@ class LocalShapedBuffer { StatusOr > ToLiteral() const; + // Transfers ownership of the encapsulated ShapedBuffer to the caller, + // analogous to std::unique_ptr::release(). 
+ ShapedBuffer Release(); + private: ScopedShapedBuffer shaped_buffer_; }; +// Result of a tuple destructuring operation on a LocalShapedBuffer -- this +// appears to be a simpler mechanism for the time being than an alternative like +// using SWIG to transform std::vectors into Python lists of SWIG objects +// directly. +class LocalShapedBufferTuple { + public: + // Note: any LocalShapedBuffer elements that are not Release()'d will be + // deallocated in the destructor. + explicit LocalShapedBufferTuple(std::vector elements); + + ~LocalShapedBufferTuple(); + + // Releases the ith element to the caller. Further attempts to release the ith + // element will return an invalid argument error. + StatusOr Release(int i); + + // Returns the number of elements in the destructured tuple. + int size() const; + + private: + std::vector elements_; +}; + +// Destructures a tuple-valued LocalShapedBuffer into its constituent elements +// in LocalShapedBufferTuple form. +StatusOr DestructureLocalShapedBufferTuple( + LocalShapedBuffer* local_shaped_buffer); + // Wraps a LocalExecutable produced by compiling a // LocalComputation. 
The Execute method forwards to that of the // underlying LocalExecutable, and additionally handles tranferring @@ -305,10 +336,12 @@ class LocalComputationBuilder { _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) + _FORWARD_UNOP(Expm1) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) + _FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) @@ -336,7 +369,6 @@ void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); void DeleteLocalComputation(LocalComputation* computation); } // namespace swig - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 536b93c6f9381ae5c84e65eb7ed264b5eb158a72..477df6fde25d0db760e08df9d335bd12e31ccb55 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -200,6 +200,20 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::LocalShapedBufferTuple*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + %typemap(out) StatusOr< std::unique_ptr > { if ($1.ok()) { std::unique_ptr value = $1.ConsumeValueOrDie(); @@ -905,6 +919,9 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; +%unignore xla::swig::LocalShapedBufferTuple; +%unignore xla::swig::LocalShapedBufferTuple::Release; +%unignore xla::swig::LocalShapedBufferTuple::size; %unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::CompiledLocalComputation::Execute; %unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers; @@ -974,10 +991,12 @@ 
tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::LocalComputationBuilder::Exp; +%unignore xla::swig::LocalComputationBuilder::Expm1; %unignore xla::swig::LocalComputationBuilder::Floor; %unignore xla::swig::LocalComputationBuilder::Ceil; %unignore xla::swig::LocalComputationBuilder::Round; %unignore xla::swig::LocalComputationBuilder::Log; +%unignore xla::swig::LocalComputationBuilder::Log1p; %unignore xla::swig::LocalComputationBuilder::Sign; %unignore xla::swig::LocalComputationBuilder::Cos; %unignore xla::swig::LocalComputationBuilder::Sin; @@ -989,6 +1008,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::ReciprocalF32; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DeleteCompiledLocalComputation; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 11611ac61287da30548c335fac977bdc255396ed..c025127c3cf1871d4def1297ed36c046cae61d4b 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -89,10 +89,12 @@ _UNARY_OPS = [ 'Not', 'Abs', 'Exp', + 'Expm1', 'Floor', 'Round', 'Ceil', 'Log', + 'Log1p', 'Sign', 'Cos', 'Sin', @@ -184,6 +186,14 @@ class LocalBuffer(object): self._delete(self.c_local_shaped_buffer) self.c_local_shaped_buffer = None + def destructure(self): + assert self.c_local_shaped_buffer is not None + result = c_api.DestructureLocalShapedBufferTuple(self.c_local_shaped_buffer) + self.c_local_shaped_buffer = None + size = result.size() + destructured = tuple(LocalBuffer(result.Release(i)) for i in xrange(size)) + return destructured + def is_deleted(self): return self.c_local_shaped_buffer is None 
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 375e720f9b433f45ad5adc329104c286184a7510..71e1d60a4e23dbfef333223c396e109533da9365 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -365,6 +365,55 @@ class LocalBufferTest(LocalComputationTest): with self.assertRaises(ValueError): compiled_c.ExecuteWithLocalBuffers([arg_buffer]) + def testDestructureTupleEmpty(self): + t = () + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 0) + + def testDestructureTupleOneArrayElement(self): + t = (np.array([1, 2, 3, 4], dtype=np.int32),) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 1) + array = pieces[0] + got = array.to_py() + want = NumpyArrayS32([1, 2, 3, 4]) + np.testing.assert_equal(want, got) + + def testDestructureTupleTwoArrayElementDifferentType(self): + t = (np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), + np.array([2, 3, 4, 5], dtype=np.int32)) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 2) + array0, array1 = pieces + got = array0.to_py() + want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0]) + np.testing.assert_equal(want, got) + got = array1.to_py() + want = NumpyArrayS32([2, 3, 4, 5]) + np.testing.assert_equal(want, got) + + def testDestructureTupleNested(self): + t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5])) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 2) + tuple0, array1 = pieces + got = array1.to_py() + want = 
NumpyArrayS32([5]) + np.testing.assert_equal(want, got) + got = tuple0.to_py() + self.assertEqual(type(got), tuple) + self.assertEqual(len(got), 2) + np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) + np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) + class SingleOpTest(LocalComputationTest): """Tests for single ops. @@ -571,6 +620,12 @@ class SingleOpTest(LocalComputationTest): c.Exp(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + def testExpm1(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Expm1(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.expm1(arr)) + def testRound(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -583,6 +638,12 @@ class SingleOpTest(LocalComputationTest): c.Log(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.log(arr)) + def testLog1p(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Log1p(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.log1p(arr)) + def testNeg(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..8999cda5ef852d1246bea45a3312575ec1ac0721 --- /dev/null +++ b/tensorflow/compiler/xla/python_api/BUILD @@ -0,0 +1,36 @@ +# Description: +# Python API for XLA. 
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_library( + name = "types", + srcs = ["types.py"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto_py", + "//third_party/py/numpy", + ], +) + +py_library( + name = "xla_shape", + srcs = ["xla_shape.py"], + visibility = ["//visibility:public"], + deps = [ + ":types", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) + +py_library( + name = "xla_literal", + srcs = ["xla_literal.py"], + visibility = ["//visibility:public"], + deps = [ + ":types", + ":xla_shape", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) diff --git a/tensorflow/compiler/xla/python_api/types.py b/tensorflow/compiler/xla/python_api/types.py new file mode 100644 index 0000000000000000000000000000000000000000..b60f8dce92ace1b2c682374a2605b3a477936bbc --- /dev/null +++ b/tensorflow/compiler/xla/python_api/types.py @@ -0,0 +1,124 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""Utilities for XLA-specific Python types.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 + +# Records correspondence between a XLA primitive type and Python/Numpy types. 
+# +# primitive_type: value of type xla_data_pb2.PrimitiveType +# numpy_dtype: corresponding Numpy "dtype" (like np.float32) +# literal_field_name: name of the field in the LiteralProto message elements +# of this type go into. +# literal_field_type: type of the field named 'literal_field_name'. +# +# TODO(eliben): figure out how to avoid knowing the extra Python type and the +# astype cast when writing into Literals. +TypeConversionRecord = collections.namedtuple('TypeConversionRecord', [ + 'primitive_type', 'numpy_dtype', 'literal_field_name', 'literal_field_type' +]) + +# Maps from XLA primitive types to TypeConversionRecord. +MAP_XLA_TYPE_TO_RECORD = { + xla_data_pb2.F16: + TypeConversionRecord( + primitive_type=xla_data_pb2.F16, + numpy_dtype=np.float16, + literal_field_name='f16s', + literal_field_type=float), + xla_data_pb2.F32: + TypeConversionRecord( + primitive_type=xla_data_pb2.F32, + numpy_dtype=np.float32, + literal_field_name='f32s', + literal_field_type=float), + xla_data_pb2.F64: + TypeConversionRecord( + primitive_type=xla_data_pb2.F64, + numpy_dtype=np.float64, + literal_field_name='f64s', + literal_field_type=float), + xla_data_pb2.S8: + TypeConversionRecord( + primitive_type=xla_data_pb2.S8, + numpy_dtype=np.int8, + literal_field_name='s8s', + literal_field_type=int), + xla_data_pb2.S16: + TypeConversionRecord( + primitive_type=xla_data_pb2.S16, + numpy_dtype=np.int16, + literal_field_name='s16s', + literal_field_type=int), + xla_data_pb2.S32: + TypeConversionRecord( + primitive_type=xla_data_pb2.S32, + numpy_dtype=np.int32, + literal_field_name='s32s', + literal_field_type=int), + xla_data_pb2.S64: + TypeConversionRecord( + primitive_type=xla_data_pb2.S64, + numpy_dtype=np.int64, + literal_field_name='s64s', + literal_field_type=int), + xla_data_pb2.U8: + TypeConversionRecord( + primitive_type=xla_data_pb2.U8, + numpy_dtype=np.uint8, + literal_field_name='s8s', + literal_field_type=int), + xla_data_pb2.U16: + TypeConversionRecord( + 
primitive_type=xla_data_pb2.U16, + numpy_dtype=np.uint16, + literal_field_name='s16s', + literal_field_type=int), + xla_data_pb2.U32: + TypeConversionRecord( + primitive_type=xla_data_pb2.U32, + numpy_dtype=np.uint32, + literal_field_name='s32s', + literal_field_type=int), + xla_data_pb2.U64: + TypeConversionRecord( + primitive_type=xla_data_pb2.U64, + numpy_dtype=np.uint64, + literal_field_name='s64s', + literal_field_type=int), + xla_data_pb2.PRED: + TypeConversionRecord( + primitive_type=xla_data_pb2.PRED, + numpy_dtype=np.bool, + literal_field_name='preds', + literal_field_type=bool) +} + +# Maps from Numpy dtypes to TypeConversionRecord. +# Note the conversion on the key. Numpy has a known issue wherein dtype hashing +# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, +# when keying by dtype in this dict, we use the string form of dtypes. +MAP_DTYPE_TO_RECORD = { + str(np.dtype(record.numpy_dtype)): record + for record in MAP_XLA_TYPE_TO_RECORD.values() +} diff --git a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py new file mode 100644 index 0000000000000000000000000000000000000000..b040098c294ffaae92b72f678947f99289239314 --- /dev/null +++ b/tensorflow/compiler/xla/python_api/xla_literal.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ====================================== +"""XLA LiteralProto utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import types +from tensorflow.compiler.xla.python_api import xla_shape + + +def ConvertLiteralToNumpyArray(literal): + """Converts a XLA literal to a Numpy array.""" + element_type = literal.shape.element_type + if element_type == xla_data_pb2.TUPLE: + return tuple( + ConvertLiteralToNumpyArray(subliteral) + for subliteral in literal.tuple_literals) + + type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type] + if not literal.shape.dimensions: + return np.array( + getattr(literal, type_record.literal_field_name)[0], + type_record.numpy_dtype) + else: + # Infer the proper Numpy order from the LiteralProto's layout. The repeated + # field representing the array's content in the Literal is linearized. + # Reading is done in two steps: + # + # 1. Read the array as 1D from the LiteralProto repeated field. + # 2. Reshape the array to its proper shape, using the right order depending + # on the LiteralProto's layout. 
+ layout_order = literal.shape.layout.minor_to_major + numpy_shape = tuple(literal.shape.dimensions) + if layout_order == range(len(literal.shape.dimensions)): + numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='F') + elif layout_order == range(len(literal.shape.dimensions) - 1, -1, -1): + numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C') + else: + raise NotImplementedError('Unsupported layout: {0}'.format(layout_order)) + ndarray = np.array( + getattr(literal, type_record.literal_field_name), + copy=False, + dtype=type_record.numpy_dtype) + return numpy_reshaper(ndarray) + + +def _ConvertNumpyArrayToLiteral(ndarray): + """Converts a Numpy array to a XLA literal.""" + type_record = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)] + literal = xla_data_pb2.LiteralProto() + literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message) + + if ndarray.ndim == 0: + getattr(literal, type_record.literal_field_name).append( + np.asscalar(ndarray.astype(type_record.literal_field_type))) + else: + # Ndarrays with boolean dtypes need special type conversion with protobufs + if ndarray.dtype in {np.bool_, np.dtype('bool')}: + for element in np.nditer(ndarray): + getattr(literal, type_record.literal_field_name).append( + type_record.literal_field_type(element)) + else: + ndarray_flat = ndarray.ravel(order='A') + getattr(literal, type_record.literal_field_name).extend(ndarray_flat) + return literal + + +def ConvertNumpyArrayToLiteral(value): + """Converts a Numpy array or a nested tuple thereof to an XLA literal.""" + if isinstance(value, tuple): + literal = xla_data_pb2.LiteralProto() + literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(value).message) + for component in value: + component_literal = literal.tuple_literals.add() + component_literal.CopyFrom(ConvertNumpyArrayToLiteral(component)) + return literal + else: + return _ConvertNumpyArrayToLiteral(value) diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py 
b/tensorflow/compiler/xla/python_api/xla_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..6af28958035bbb03e7e1dbb0d0c7bb2c2f25b96d --- /dev/null +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -0,0 +1,155 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""XLA Shape utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import types + + +class Shape(object): + """Wraps a xla_data_pb2.Shape message with a convenient Python type. + + Provides direct access to the underlying xla_data_pb2.Shape message in the + message attribute, along with accessor wrappers to the message's fields. + Avoid direct access to .message unless interacting directly with protobuf APIs + like CopyFrom. In other words, prefer hauling the shape around in a Shape, and + only access .message when strictly required by the protobuf API. + """ + + def __init__(self, element_type, dimensions, layout=None): + """Creates a new XLA Shape. + + Args: + element_type: element type from xla_data_pb2. + dimensions: sequence of dimensions sizes (integers), or sequence + of Shapes in the case of a tuple, i.e. when element_type is + TUPLE. 
+ layout: optional minor_to_major sequence for layout. If not given, the + default major-to-minor layout is used. + + Raises: + ValueError: if element_type is TUPLE but dimensions are not Shape objects. + """ + self.message = xla_data_pb2.Shape() + self.message.element_type = element_type + if element_type == xla_data_pb2.TUPLE: + if not all(isinstance(subshape, Shape) for subshape in dimensions): + raise ValueError( + 'XLA tuple requires sequence of Shape objects as dimensions') + self._tuple_shapes = tuple(dimensions) + for component_shape in self._tuple_shapes: + component_message = self.message.tuple_shapes.add() + component_message.CopyFrom(component_shape.message) + else: + self.message.dimensions.extend(dimensions) + if layout is None: + layout = list(reversed(range(len(dimensions)))) + self.message.layout.format = xla_data_pb2.DENSE + self.message.layout.minor_to_major.extend(layout) + + def element_type(self): + return self.message.element_type + + def is_tuple(self): + return self.element_type() == xla_data_pb2.TUPLE + + def dimensions(self): + if self.is_tuple(): + raise ValueError('Tuple shape has no dimensions. Try tuple_shapes()?') + return self.message.dimensions + + def tuple_shapes(self): + """If this is a tuple, returns its sequence of constituent Shape objects. + + Returns: + Tuple sub-shapes. + + Raises: + ValueError: if this is not a tuple. + """ + if not self.is_tuple(): + raise ValueError('tuple_shapes() called on a non-tuple shape') + return self._tuple_shapes + + def layout(self): + return self.message.layout + + @staticmethod + def from_pyval(pyval): + return CreateShapeFromNumpy(pyval) + + +def _CreateShapeFromNumpy(ndarray): # pylint: disable=invalid-name + """Create a Shape from a given Numpy array. + + Args: + ndarray: Numpy array. + + Returns: + A Shape object. + """ + element_type = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)].primitive_type + dimensions = ndarray.shape + + # Set the shape's layout based on the ordering of ndarray. 
+ # Numpy arrays come in two orders: Fortran (column-major) and C (row-major). + if np.isfortran(ndarray): + # Column-major layout. This corresponds to a "dimension order is + # minor-to-major" layout in XLA. + layout = range(ndarray.ndim) + else: + # Row-major layout. This corresponds to a "dimension order is + # major-to-minor" layout in XLA. + layout = list(reversed(xrange(ndarray.ndim))) + + return Shape(element_type, dimensions, layout) + + +def CreateShapeFromNumpy(value): # pylint: disable=invalid-name + """Create a Shape from a Numpy array or a nested tuple structure thereof. + + Args: + value: Numpy array or (possibly nested) tuple structure that bottoms out in + Numpy arrays. + + Returns: + A Shape object. + """ + if isinstance(value, tuple): + return Shape( + xla_data_pb2.TUPLE, + [CreateShapeFromNumpy(component) for component in value]) + else: + return _CreateShapeFromNumpy(value) + + +def CreateShapeFromDtypeAndTuple(dtype, shape_tuple): # pylint: disable=invalid-name + """Create a shape from a Numpy dtype and a sequence of nonnegative integers. + + Args: + dtype: a numpy dtype, e.g. np.dtype('int32'). + shape_tuple: a sequence of nonnegative integers. + + Returns: + A Shape object. 
+ """ + element_type = types.MAP_DTYPE_TO_RECORD[str(dtype)].primitive_type + return Shape(element_type, shape_tuple) diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 0d56a9a477b15964ad45e798865aa8d2c7385073..0b1cec1925d4424db086f8a3f62c91ede090189c 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -39,10 +39,10 @@ tf_cc_binary( srcs = ["grpc_service_main.cc"], deps = [ ":grpc_service", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "@grpc//:grpc++_unsecure", ], ) @@ -54,6 +54,7 @@ tf_cc_test( ], deps = [ ":grpc_stub", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -61,7 +62,6 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@grpc//:grpc++_unsecure", ], ) @@ -71,9 +71,9 @@ cc_library( hdrs = ["grpc_service.h"], deps = [ ":xla_service_proto", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++_unsecure", ], ) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 313f11a9a957155eb277dc02ba5d2565c87e0235..d7dd9786a2bbde2d18ae81a9a9d4cc4b2cc38411 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -20,8 +20,8 @@ limitations under the License. 
#include #include -#include "grpc++/create_channel.h" -#include "grpc++/security/credentials.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index 5cd573167ae8c002ad8f09e8ba3fb25c6f356564..ca1b09b648013ad45d806040c5ddcf11d9e5604e 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ #define TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ -#include "grpc++/server_context.h" +#include "grpcpp/server_context.h" #include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index e29908ccec80db76e3b5b856e57382c56430c379..c68c857c304138ff4318e243f66547c6acce1005 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -15,9 +15,9 @@ limitations under the License. // Basic server binary that exposes a xla::Service through a GRPC interface // on a configurable port. 
-#include "grpc++/security/server_credentials.h" -#include "grpc++/server.h" -#include "grpc++/server_builder.h" +#include "grpcpp/security/server_credentials.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_service.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index 92eb19ec0f9696974556be01a93c074846f6c23a..551ae895e05586daec0ffcd425f4950f76bdd50d 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -115,10 +115,6 @@ service XlaService { returns (ComputeConstantResponse) { } - // Retrieves the inferred shape for a value within a computation. - rpc GetLocalShape(GetLocalShapeRequest) returns (GetLocalShapeResponse) { - } - // Requests one or more device handles from the target. The returned device // handles can be used to specify the device on which to execute computations // or transfer data. @@ -132,18 +128,6 @@ service XlaService { returns (CreateChannelHandleResponse) { } - // Requests that the referenced computation be specialized for the provided - // arguments for subsequent execution. This permits things such as value - // specialization. - rpc Specialize(SpecializeRequest) returns (SpecializeResponse) { - } - - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) { - } - // Invokes the provided computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. 
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 20cc671ba393769afa1dd2c964197a87c1835504..8a1d1bf73d51d81f6a9cf353c0bd0591231f5225 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -292,7 +292,6 @@ cc_library( ":hlo_proto", ":hlo_reachability", ":name_uniquer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -401,17 +400,6 @@ tf_cc_test( ], ) -cc_library( - name = "versioned_computation_handle", - srcs = ["versioned_computation_handle.cc"], - hdrs = ["versioned_computation_handle.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - tf_cc_test( name = "hlo_instruction_test", srcs = ["hlo_instruction_test.cc"], @@ -591,7 +579,6 @@ cc_library( ":allocation_tracker", ":backend", ":channel_tracker", - ":compilation_cache", ":compiler", ":computation_layout", ":device_memory_allocator", @@ -606,7 +593,6 @@ cc_library( ":platform_util", ":source_map_util", ":transfer_manager", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:service_interface", @@ -641,7 +627,6 @@ cc_library( ":platform_util", ":service", ":shaped_buffer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -762,7 +747,6 @@ cc_library( ":hlo_proto", ":pool", ":shaped_buffer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -864,7 +848,6 @@ cc_library( hdrs = ["channel_tracker.h"], deps = [ ":hlo", - ":versioned_computation_handle", "//tensorflow/compiler/xla:status", 
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1118,6 +1101,7 @@ tf_cc_test( srcs = ["hlo_scheduling_test.cc"], deps = [ ":buffer_value", + ":heap_simulator", ":hlo", ":hlo_ordering", ":hlo_scheduling", @@ -1165,6 +1149,19 @@ tf_cc_test( ], ) +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_creation_utils", srcs = ["hlo_creation_utils.cc"], @@ -1646,7 +1643,6 @@ tf_cc_test( ":hlo_cost_analysis", ":local_service", ":service", - ":versioned_computation_handle", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", @@ -1987,20 +1983,6 @@ tf_cc_test( ], ) -cc_library( - name = "compilation_cache", - srcs = ["compilation_cache.cc"], - hdrs = ["compilation_cache.h"], - deps = [ - ":executable", - ":hlo_module_config", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "layout_assignment", srcs = [ @@ -2142,6 +2124,7 @@ cc_library( ":buffer_liveness", ":buffer_value", ":call_graph", + ":copy_insertion", ":flatten_call_graph", ":hlo", ":hlo_dce", @@ -2149,6 +2132,7 @@ cc_library( ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -2162,6 +2146,7 @@ tf_cc_test( name = "hlo_rematerialization_test", srcs = ["hlo_rematerialization_test.cc"], deps = [ + ":flatten_call_graph", ":hlo", ":hlo_matchers", ":hlo_ordering", @@ -2171,6 +2156,7 @@ tf_cc_test( 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -2397,7 +2383,6 @@ cc_library( ":hlo_graph_dumper", ":hlo_pass", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", ], ) @@ -2561,7 +2546,6 @@ cc_library( name = "hlo_tfgraph_builder", srcs = ["hlo_tfgraph_builder.cc"], hdrs = ["hlo_tfgraph_builder.h"], - visibility = ["//tensorflow/compiler/xla/tools:__pkg__"], deps = [ ":hlo", "//tensorflow/compiler/xla:literal_util", @@ -2592,6 +2576,7 @@ cc_library( hdrs = ["hlo_graph_dumper.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", "//tensorflow/compiler/xla:literal_util", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index dc5f1b31bf8510be404491b7bceb36f73f4cbf75..1fc8fb9b6994db78fe3aa06e1ea790decfce7b97 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -449,7 +449,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( // Filter out and remove empty operands. std::vector nonempty_operands; for (HloInstruction* operand : operands) { - if (!ShapeUtil::HasZeroElements(operand->shape())) { + if (!ShapeUtil::IsZeroElementArray(operand->shape())) { nonempty_operands.push_back(operand); } } @@ -1058,9 +1058,9 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } // Replace a zero element dot with a broadcast of the constant 0. 
- if (ShapeUtil::HasZeroElements(dot->shape()) || - ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { + if (ShapeUtil::IsZeroElementArray(dot->shape()) || + ShapeUtil::IsZeroElementArray(lhs->shape()) || + ShapeUtil::IsZeroElementArray(rhs->shape())) { auto zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); return ReplaceWithNewInstruction( @@ -1392,7 +1392,7 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { } Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { - if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) { + if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) { return ReplaceWithNewInstruction( pad, HloInstruction::CreateBroadcast(pad->shape(), pad->mutable_operand(1), {})); @@ -1638,7 +1638,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // Reshape directly to empty constant if the shape contains zero-element // dimension. - if (ShapeUtil::HasZeroElements(reshape->shape())) { + if (ShapeUtil::IsZeroElementArray(reshape->shape())) { auto empty_constant = HloInstruction::CreateConstant( Literal::CreateFromShape(reshape->shape())); @@ -1739,7 +1739,7 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( // If any dimension of update is 0, elide the DynamicUpdateSlice. This // optimization becomes invalid should we later prefer to warn about out of // bound indices. 
- if (ShapeUtil::HasZeroElements(update->shape())) { + if (ShapeUtil::IsZeroElementArray(update->shape())) { return ReplaceInstruction(dynamic_update_slice, dynamic_update_slice->mutable_operand(0)); } @@ -1751,8 +1751,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { auto init_value = reduce->mutable_operand(1); tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); - if (ShapeUtil::HasZeroElements(arg->shape()) || - ShapeUtil::HasZeroElements(reduce->shape())) { + if (ShapeUtil::IsZeroElementArray(arg->shape()) || + ShapeUtil::IsZeroElementArray(reduce->shape())) { return ReplaceWithNewInstruction( reduce, HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); @@ -1783,6 +1783,37 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); } + + // If a reduce feeds a reduce with the same computation and initial value, + // they can be combined into a single reduce. + if (arg->opcode() == HloOpcode::kReduce && + init_value->Identical(*arg->operand(1)) && + *function == *arg->to_apply()) { + // Create a new reduce with the combined reduction dimensions of both + // reduces. + std::vector arg_dims = arg->dimensions(); + std::sort(arg_dims.begin(), arg_dims.end()); + std::vector reduce_dims = reduce->dimensions(); + std::sort(reduce_dims.begin(), reduce_dims.end()); + // Transform reduce_dims to the same rank as the operand of the operand. 
+ for (int64 arg_dim : arg_dims) { + for (int64& dim : reduce_dims) { + if (dim >= arg_dim) { + ++dim; + } + } + } + std::vector new_dimensions; + new_dimensions.reserve(arg->dimensions().size() + + reduce->dimensions().size()); + std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), + reduce_dims.end(), std::back_inserter(new_dimensions)); + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0), + init_value, new_dimensions, function)); + } + // A reshape that collapses multiple dimensions into a dimension being // reduced can just reduce all of those dimensions instead of doing a // collapsing reshape before a reduction. @@ -1832,7 +1863,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* reduce_window) { - if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) { + if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) { return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateBroadcast(reduce_window->shape(), @@ -2028,8 +2059,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction* convolution) { auto lhs = convolution->mutable_operand(0); auto rhs = convolution->mutable_operand(1); - if (ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { + if (ShapeUtil::IsZeroElementArray(lhs->shape()) || + ShapeUtil::IsZeroElementArray(rhs->shape())) { return ReplaceWithNewInstruction( convolution, HloInstruction::CreateBroadcast( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index cda157f9fac1639d792fb55b5a5ddac56df271aa..2605b0488cb7c6850746df94c4ab05d6b5d35de5 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -74,6 +74,44 @@ 
TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +// Test that Reduce(Reduce(A)) -> Reduce(A) +TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { + HloComputation::Builder builder(TestName()); + // Create add computation. + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r4f32, "param")); + std::vector dims0({0}); + Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7}); + HloInstruction* reduce0 = builder.AddInstruction( + HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation)); + std::vector dims1({1, 2}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); + builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, + dims1, add_computation)); + module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reduce(param, zero)); + EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); +} + // Test that Const + A is canonicalized to A + Const. 
TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -1714,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1759,7 +1797,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1781,7 +1819,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1804,7 +1842,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1932,7 +1970,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. 
+ auto module = HloTestBase::CreateNewModule(); auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, @@ -2060,7 +2099,7 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2090,7 +2129,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2121,7 +2160,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2151,7 +2190,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); @@ -2184,7 +2223,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), 
op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2200,10 +2239,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction::CreateParameter(0, r0f32, "scalar_param")); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, scalar_param, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {})); Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3}); HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( @@ -2219,10 +2256,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); // Running simplification again should not result in any further changes. 
- ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2237,10 +2274,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, forty_two, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {})); HloInstruction* transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -2259,7 +2294,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2268,7 +2303,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2349,7 +2385,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. 
+ auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2444,7 +2481,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 598718c72c6941a4859063ed894c45b9c620998e..ec13fadbc75e2315d1d6ef72e24a0faca0c7de40 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -58,8 +58,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { // Runs the visitor on a computation. static bool Run(HloComputation* computation, bool rewrite_training_op, - bool rewrite_inference_op, bool rewrite_grad_op, - bool use_fusion); + bool rewrite_inference_op, bool rewrite_grad_op); // Returns whether any batch norm ops were rewritten. 
const bool changed() const { return changed_; } @@ -70,21 +69,14 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { explicit BatchNormExpanderVisitor(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) + bool rewrite_grad_op) : computation_(computation), rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} + rewrite_grad_op_(rewrite_grad_op) {} HloComputation* GetOrCreateScalarAddComputation( PrimitiveType primitive_type) { - HloComputation** scalar_add_computation = - &scalar_add_computations_[primitive_type]; - if (*scalar_add_computation) { - return *scalar_add_computation; - } - HloComputation::Builder b("scalar_add_computation"); Shape shape = ShapeUtil::MakeShape(primitive_type, {}); auto scalar_lhs = b.AddInstruction( @@ -93,71 +85,38 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { HloInstruction::CreateParameter(1, shape, "scalar_rhs")); auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); - *scalar_add_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return *scalar_add_computation; - } - - // TODO(b/80534766): Remove maps after performance issues with scalar - // broadcasts are resolved on all backends. 
- HloComputation* GetOrCreateScalarRsqrtComputation( - PrimitiveType primitive_type) { - HloComputation** scalar_rsqrt_computation = - &scalar_rsqrt_computations_[primitive_type]; - if (*scalar_rsqrt_computation) { - return *scalar_rsqrt_computation; - } - - HloComputation::Builder b("scalar_add_computation"); - Shape shape = ShapeUtil::MakeShape(primitive_type, {}); - auto scalar_lhs = b.AddInstruction( - HloInstruction::CreateParameter(0, shape, "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert( - shape, b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(-0.5f))))); - auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kPower, scalar_lhs, scalar_rhs)); - *scalar_rsqrt_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return *scalar_rsqrt_computation; + return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); } - std::unique_ptr Rsqrt(HloInstruction* operand) { - return HloInstruction::CreateMap( - operand->shape(), {operand}, - GetOrCreateScalarRsqrtComputation(operand->shape().element_type())); - } - - HloComputation* GetOrCreateScalarMeanComputation(PrimitiveType primitive_type, - int64 element_count) { - HloComputation** scalar_mean_computation = - &scalar_mean_computations_[std::pair( - primitive_type, element_count)]; - if (*scalar_mean_computation) { - return *scalar_mean_computation; - } - - HloComputation::Builder b("scalar_add_computation"); - Shape shape = ShapeUtil::MakeShape(primitive_type, {}); - auto scalar_lhs = b.AddInstruction( - HloInstruction::CreateParameter(0, shape, "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert( - shape, b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0( - 1.0f / static_cast(element_count)))))); - auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kMultiply, scalar_lhs, scalar_rhs)); - 
*scalar_mean_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return *scalar_mean_computation; + std::unique_ptr Rsqrt( + HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(-0.5f))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower, + operand, exponent); } - std::unique_ptr Mean(int64 element_count, - HloInstruction* operand) { - return HloInstruction::CreateMap( - operand->shape(), {operand}, - GetOrCreateScalarMeanComputation(operand->shape().element_type(), - element_count)); + std::unique_ptr Mean( + int64 element_count, HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* elem_count_recip = + add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(1.0 / element_count))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, + operand, elem_count_recip); } // Replaces the existing HLO instruction old_instruction, with @@ -189,18 +148,9 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { bool rewrite_training_op_; bool rewrite_inference_op_; bool rewrite_grad_op_; - bool use_fusion_; // Whether rewrite has occurred. bool changed_ = false; - - // Cached computations for adding two scalars. 
- tensorflow::gtl::FlatMap - scalar_add_computations_; - tensorflow::gtl::FlatMap - scalar_rsqrt_computations_; - tensorflow::gtl::FlatMap, HloComputation*> - scalar_mean_computations_; }; } // namespace @@ -208,13 +158,12 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) { + bool rewrite_grad_op) { BatchNormExpanderVisitor visitor( computation, /*rewrite_training_op=*/rewrite_training_op, /*rewrite_inference_op=*/rewrite_inference_op, - /*rewrite_grad_op=*/rewrite_grad_op, - /*use_fusion=*/use_fusion); + /*rewrite_grad_op=*/rewrite_grad_op); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -290,28 +239,14 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); - // Fuse two parallel reduces together to improve performance. - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum, squared_sum, operand_squared}, - HloInstruction::FusionKind::kInput); - - sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - squared_sum = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // E[X]. - auto mean = add(Mean(elements_per_feature_int64, sum)); + auto mean = add(Mean(elements_per_feature_int64, sum, add)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = add(Mean(elements_per_feature_int64, squared_sum)); + auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add)); // E^2[X]. 
auto mean_square = @@ -329,7 +264,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, @@ -431,7 +366,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, @@ -545,10 +480,12 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( // rsqrt[Var[X] + epsilon]. auto rsqrt_var_add_epsilon_broadcasted = add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon_activation))); + variance_broadcasted, epsilon_activation), + add)); auto rsqrt_var_add_epsilon = add(Rsqrt( - add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature))); + add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature), + add)); // X - E[X]. 
auto activation_minus_mean = add_binary( @@ -573,21 +510,6 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple( - {sum_grad_output_times_activiation_minus_mean, grad_beta})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, - HloInstruction::FusionKind::kInput); - - sum_grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - grad_beta = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply, sum_grad_output_times_activiation_minus_mean, @@ -616,8 +538,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = - add(Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon)); + scale_times_rsqrt_var_add_epsilon = add( + Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); auto elements_per_feature_literal = Literal::CreateR0(elements_per_feature_int64); @@ -665,8 +587,8 @@ StatusOr BatchNormExpander::Run(HloModule* module) { bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_, - rewrite_inference_op_, rewrite_grad_op_, - use_fusion_)) { + rewrite_inference_op_, + rewrite_grad_op_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 4ad987085da91684bb7891070afeefd19be4138f..7ae202c583516443a6263403fb5460d1adbabd97 100644 --- 
a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -31,11 +31,10 @@ class BatchNormExpander : public HloPassInterface { // When use_fusion is set, a multi-output fusion node is created. BatchNormExpander(bool rewrite_training_op = false, bool rewrite_inference_op = false, - bool rewrite_grad_op = false, bool use_fusion = true) + bool rewrite_grad_op = false) : rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} + rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; tensorflow::StringPiece name() const override { return "batchnorm_expander"; } @@ -47,7 +46,6 @@ class BatchNormExpander : public HloPassInterface { bool rewrite_training_op_; bool rewrite_inference_op_; bool rewrite_grad_op_; - bool use_fusion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 7fd1e733e96da95cf43d9861af6d48a1850051c8..f7b4c1405dbc8719d8fba5476e6e41d2921ea877 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -235,7 +235,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum)); + sum, /*replica_group_ids=*/{}, /*barrier=*/"")); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 
9926661dd30600b2bf20e7f137aa50d9fbfd7c82..830f26422bdc2b3bd789e7d5926bcebac815d34a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -250,8 +250,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, - reduction)); + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, + /*replica_group_ids=*/{}, /*barrier=*/"")); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index ed0746980f87ac2bea79c308644dc63769f9e309..8f1d2f0804960b04dbff4c990c356589a609ce8d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -631,7 +631,7 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) { subshape, converted_outputs.element(parent_index), output_index.back())); } - if (ShapeUtil::IsTuple(subshape)) { + if (!ShapeUtil::IsArray(subshape)) { continue; } if (!ShapeUtil::Compatible( diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 682c3865797c85eedf3949738f3372857f146c0e..afe4b2e1425f9e84320ffd5f08beceaac8168c22 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -633,7 +633,7 @@ Status BufferAssignment::ComputeSummaryStats() { if (module_sequence.size() == module_->computation_count()) { TF_ASSIGN_OR_RETURN( const int64 min_size, - MinimumMemoryForSequence(module_sequence, buffer_size_)); + HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_)); 
stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 7e86c33687e595ad154361dd7018791299cc56ab..efa4696130ffeff669b0d674438a45c5a9d48ef2 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -371,11 +371,11 @@ TEST_F(BufferAssignmentTest, Basic) { // param1[100] --------------/--------/ auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -418,11 +418,11 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { // share anything. 
auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -477,11 +477,11 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { // have the color 0, which allows the mul and add to share buffers. auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -547,11 +547,11 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { // auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - 
HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -601,7 +601,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { // Creates the main kernel and verifies instruction counts. auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation)); module->AddEntryComputation(builder.Build()); @@ -654,7 +654,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0)); auto exp2 = builder.AddInstruction( @@ -818,7 +818,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32vec100_, "")); + HloInstruction::CreateParameter(0, f32vec100_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0)); auto tanh = builder.AddInstruction( @@ -1496,11 +1496,11 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { // param1[100] --------------/--------/ auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = 
builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -1536,7 +1536,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { // be {%rev, %neg, %concat}. This occurs right at the concat itself. auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32vec100_, "")); + HloInstruction::CreateParameter(0, f32vec100_, "p")); auto log = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param)); auto rev = builder.AddInstruction( @@ -1673,7 +1673,7 @@ class WhileBufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { auto sequence = - CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( module, xla::MakeUnique(module, sequence), ByteSizeOf, @@ -2103,7 +2103,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module.get()); auto sequence = - CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo sequence for the // root computation, so we overwrite that entry with a manually diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index 52f33a1318e91d3f5941a5d68051e4c207661bbc..fac0afd672ff3ed083aacf778dd9c4f90a2ee870 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h 
+++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc deleted file mode 100644 index b16907da9e9c909d2639f83895db27d724a84a7b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/compilation_cache.h" - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -std::shared_ptr CompilationCache::Insert( - std::unique_ptr executable, - const HloModuleConfig& module_config) { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = - BuildKey(executable->entry_computation_handle(), module_config); - VLOG(2) << "inserting cache key: " << key; - if (cache_.count(key) == 0) { - cache_.emplace(key, std::move(executable)); - } else { - // Executable already exists in the cache. This can happen if two Execute - // calls for a new computation are received simultaneously by the - // service. In this case, we discard the Executable given as a parameter and - // return what is in the cache. This is necessary because the service relies - // on the cache to keep ownership of the Executable. We only want to store - // one Executable for a given computation version and we can't discard the - // executable which is in the cache because it may be in use. 
- executable.reset(); - } - return cache_.at(key); -} - -std::shared_ptr CompilationCache::LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = BuildKey(versioned_handle, module_config); - VLOG(2) << "looking up cache key: " << key; - if (cache_.count(key) == 0) { - VLOG(2) << "cache key not found: " << key; - return nullptr; - } else { - std::shared_ptr result = cache_.at(key); - VLOG(2) << "hit executable with module config: " - << result->module_config().compilation_cache_key(); - return result; - } -} - -CompilationCache::CacheKey CompilationCache::BuildKey( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - // The computation shape is represented entirely by its ProgramShape member, - // so just serialize the proto as part of the key. - return tensorflow::strings::StrCat(versioned_handle.handle.handle(), "::", - versioned_handle.version, "::", - module_config.compilation_cache_key()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h deleted file mode 100644 index 09989726ae6629aa65cb1dd84c16408a75019fa5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.h +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" - -namespace xla { - -// A cache which stores Executables indexed by computation handle and version. -class CompilationCache { - public: - CompilationCache() {} - - // Insert the given Executable into the cache. Return a bare Executable - // pointer for the caller to use. Note: the returned pointer will *not* be the - // same as the given unique pointer if the computation already exists in the - // cache. See comments in the .cc implementation for details of this case. - // - // module_config is provided by the caller, instead of being taken from the - // executable, so that we can insert keys into the compilation cache that are - // devoid of layout (where XLA gets to choose what layout to compile). - // - // A shared_ptr is returned so the caller can keep the Executable from being - // destructed in the event that the Executable is evicted from the - // computation cache (and the cache's shared_ptr to the Executable is - // destructed). - std::shared_ptr Insert(std::unique_ptr executable, - const HloModuleConfig& module_config); - - // Lookup the Executable for the specified versioned computation in the cache. - // Return a shared_ptr to the Executable if it exists in the cache. Return - // nullptr otherwise. 
- std::shared_ptr LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - - protected: - mutable tensorflow::mutex mutex_; - - // Map from versioned handle with program layout to Executable built - // for that computation version and program layout. - using CacheKey = string; - - CacheKey BuildKey(const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - std::map> cache_ GUARDED_BY(mutex_); - - private: - TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index d8fdccf9bbf1c1788bb4000aa702292362446503..7426672a7a2a9102bd5ea98bd51092982e1e09b4 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -63,7 +63,8 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, StatusOr>> CompileOnlyService::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { + const AotCompilationOptions& options, + std::unique_ptr* metadata) { std::vector> hlo_modules; for (const AotXlaComputationInstance& instance : computations) { TF_RET_CHECK(instance.computation.has_program_shape()); @@ -100,7 +101,8 @@ CompileOnlyService::CompileAheadOfTime( hlo_modules.push_back(std::move(hlo_module)); } - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); + return compiler_->CompileAheadOfTime(std::move(hlo_modules), options, + metadata); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index e6a66c202d6e0df3cb6d165e51beb25abd8ec45c..1ac950bdd66bd034dfdafa8598ec506221e99c2f 100644 --- 
a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -53,6 +53,12 @@ class CompileOnlyService : public Service { const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& options); + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options, + std::unique_ptr* metadata); + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 6f06bba6798bdff51f10d8fe9dc524d8064ba849..0dceed853dcbae211657f00433866cfe10c51fc7 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -35,6 +35,20 @@ Compiler::ComputeBackendConfigs(const HloInstruction& hlo, return {}; } +// Define a default version where metadata is not used. +StatusOr>> +Compiler::CompileAheadOfTime( + std::vector> modules, + const AotCompilationOptions& options, + std::unique_ptr* metadata) { + if (metadata != nullptr) { + return Unimplemented( + "Populating AotCompilationMetadata is not implemented on this " + "compiler."); + } + return CompileAheadOfTime(std::move(modules), options); +} + /* static */ std::map* Compiler::GetPlatformCompilerFactories() { static auto* r = new std::map; diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 6c52ffd800d19de83877341d41ef81eee2de7251..d1144f97bb2ab29d3d18f3b3f65a38af46e68dd1 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -94,6 +94,19 @@ class AotCompilationOptions { DebugOptions debug_options_; }; +// Abstract superclass describing metadata produced during ahead-of-time +// compilation. 
+class AotCompilationMetadata { + public: + AotCompilationMetadata(const AotCompilationMetadata&) = delete; + AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete; + + virtual ~AotCompilationMetadata() = default; + + protected: + AotCompilationMetadata() = default; +}; + // Abstract compiler interface that is subclassed for compilation on a // particular platform. // @@ -172,6 +185,13 @@ class Compiler { CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& options) = 0; + // Similar to CompileAheadOfTime above but AotCompilationMetadata + // has an argument that can be populated during compilation. + virtual StatusOr>> + CompileAheadOfTime(std::vector> modules, + const AotCompilationOptions& options, + std::unique_ptr* metadata); + ///// // The Compiler class also serves as a point to register compiler objects // for the various platforms. diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 53c3a3f7b738687db3098acfaef1ae87860d0440..6975f387b4864bf28ea0ad23d7d4602b5b346e08 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -32,12 +32,21 @@ namespace xla { // mutable layouts. class ComputationLayout { public: + // Creates a new ComputationLayout with the given result layout. + explicit ComputationLayout(ShapeLayout result_layout) + : result_layout_(std::move(result_layout)) {} + // Constructs a ComputationLayout from a ProgramShape. The layouts of the // parameters and results are set to the default layout. Layouts in the // ProgramShape are ignored if ignore_layouts is true. explicit ComputationLayout(const ProgramShape& program_shape, bool ignore_layouts = true); + // Adds a new parameter layout to the computation layout. 
+ void add_parameter_layout(ShapeLayout shape_layout) { + parameter_layouts_.push_back(std::move(shape_layout)); + } + // Returns the layout of a particular parameter. const ShapeLayout& parameter_layout(int64 param_no) const { return parameter_layouts_[param_no]; diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 33d8338809d4e8c7c4774f062c3dda5494543ca6..e0ce2e3555e7746d6df212123fe1f968937cceed 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -472,6 +472,10 @@ class CopyRemover { // between copies added around aliased operations (kWhile) guarantees // this strict order. for (const HloValue* value_a : buffer.values()) { + if (ShapeUtil::IsToken(value_a->shape())) { + // Token values have no representation and cannot interfere. + continue; + } for (const HloValue* value_b : buffer.values()) { if (value_a != value_b) { DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, @@ -613,7 +617,10 @@ class CopyRemover { VLOG(2) << copy->name() << " is not removable"; return false; } - + if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) { + VLOG(2) << copy->name() << " is not removable (shape mismatch)"; + return false; + } const CopyNodes& copy_node = copy_map_.at(copy); ValueNode* src = copy_node.src; ValueNode* dest = copy_node.dest; @@ -947,28 +954,6 @@ class CopyRemover { BufferValueTracker buffer_value_tracker_; }; -// Try to remove as many copies from the module as possible without introducing -// live range interference. Copy instructions (identified by their unique id) in -// the set copies_to_exclude are not considered for removal. 
-Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); - CopyRemover copy_remover(*alias_analysis, ordering, module); - XLA_VLOG_LINES(3, copy_remover.ToString()); - - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - !ContainsKey(copies_to_exclude, instruction->unique_id())) { - TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); - } - } - } - return Status::OK(); -} - // Add copies to address special constraints on the roots of computations not // related to live range interference: // @@ -1065,13 +1050,23 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { HloInstruction* instruction = pair.first; const ShapeTree& indices_to_copy = pair.second; + ShapeTree copies_added(indices_to_copy.shape()); std::vector users = instruction->users(); TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, instruction->parent()->DeepCopyInstruction( - instruction, &indices_to_copy)); + instruction, &indices_to_copy, &copies_added)); for (HloInstruction* user : users) { TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); } + // Special case copies are not eligible for later copy elision passes. 
+ indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) { + if (has_copy) { + HloInstruction* copy = *copies_added.mutable_element(index); + if (copy != nullptr) { + copy->SetCopyElisionAllowed(false); + } + } + }); if (instruction == instruction->parent()->root_instruction()) { instruction->parent()->set_root_instruction(deep_copy); } @@ -1097,6 +1092,31 @@ void MaybeDumpModule(const string& message, const HloModule& module) { } // namespace +Status RemoveUnnecessaryCopies( + const HloOrdering& ordering, + const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module) { + MaybeDumpModule("after adding copies to resolve interference", *module); + + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + CopyRemover copy_remover(*alias_analysis, ordering, module); + XLA_VLOG_LINES(3, copy_remover.ToString()); + + std::unique_ptr call_graph = CallGraph::Build(module); + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy && + !ContainsKey(copies_to_exclude, instruction->unique_id()) && + instruction->CopyElisionAllowed()) { + TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + } + } + } + MaybeDumpModule("after removing unnecessary copies", *module); + + return Status::OK(); +} + StatusOr CopyInsertion::Run(HloModule* module) { // Copy insertion is performed in three steps: // @@ -1158,14 +1178,10 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); - MaybeDumpModule("after adding copies to resolve interference", *module); - DependencyHloOrdering ordering(module); TF_RETURN_IF_ERROR( RemoveUnnecessaryCopies(ordering, existing_copies, module)); - MaybeDumpModule("after removing unnecessary copies", *module); - TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); MaybeDumpModule("after adding special-case 
copies", *module); diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 65e3d31e347e2cb249a072e7d06ca10c55401748..0d7b3c20f982cae21e5160fe5be20c85bf940ed7 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -64,6 +64,13 @@ class CopyInsertion : public HloPassInterface { static StatusOr AddCopiesForBufferAssignment(HloModule* module); }; +// Try to remove as many copies from the module as possible without introducing +// live range interference. Copy instructions (identified by their unique id) in +// the set copies_to_exclude are not considered for removal. +Status RemoveUnnecessaryCopies( + const HloOrdering& ordering, + const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 153f062d015e49db11c4c9ae0a2a61e76c020f02..ed1a50f516ee23e0f034bf5c2ed15fac7a70c3cc 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1595,6 +1595,45 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { EXPECT_THAT(condition->root_instruction(), op::Constant()); } +TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) { + string module_string = R"( +HloModule TokensShouldNotBeCopied + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %generate-token = token[] generate-token(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) 
tuple(s32[] %add, token[] %generate-token) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %TokensShouldNotBeCopied () -> s32[] { + %one = s32[] constant(1) + %negative_one = s32[] negate(%one) + %init_token = token[] generate-token() + %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + HloRunner::CreateModuleFromString( + module_string, GetDebugOptionsForTest())); + InsertCopies(module.get()); + + // There should be no copies added because tokens should not be copied. + EXPECT_EQ(CountCopies(*module), 0); +} + std::unique_ptr MakeTrivialCondition(const Shape& shape) { auto builder = HloComputation::Builder("trivial_condition"); builder.AddInstruction( @@ -1636,8 +1675,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_SequentialWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1677,8 +1715,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto 
builder = HloComputation::Builder("BM_ParallelWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1750,8 +1787,7 @@ void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { std::vector tuple_params(num_tuple_inputs); for (int i = 0; i < num_iters; ++i) { auto builder = HloComputation::Builder("BM_ParallelWhiles"); - HloModule module("BM_ManyElementTuple", VersionedComputationHandle(), - config); + HloModule module("BM_ManyElementTuple", config); for (int j = 0; j < num_tuple_inputs; ++j) { tuple_params[j] = builder.AddInstruction( HloInstruction::CreateParameter(j, element_shape, "")); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 278bb1bebfa1a0d76d0268b6b6fcfa87410ceee5..b703be0f39e2032bc58479f0b957f9d8b01a77c3 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -151,7 +151,14 @@ cc_library( "@llvm//:target", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep "@llvm//:x86_disassembler", # fixdeps: keep - ], + ] + select({ + "@org_tensorflow//tensorflow:linux_ppc64le": [ + "@llvm//:powerpc_disassembler", + "@llvm//:powerpc_code_gen", + ], + "//conditions:default": [ + ], + }), alwayslink = True, # Contains compiler registration ) @@ -898,6 +905,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 25b18eff20f901fc34343a12bfbd353ecec49cfb..d039132535071661d047579587385210719fede3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -264,8 +264,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pass.AddPass( 
/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/false); + /*rewrite_grad_op=*/true); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -550,8 +549,8 @@ StatusOr> CpuCompiler::RunBackend( // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction(), - DFSMemoryScheduler)); + ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -730,7 +729,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); + ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. 
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index cda623f8e87a3bce7df824d89d863616413b89c6..e8b205051e2828b8f1d3ecd2161ae9d53d3f1796 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -324,11 +324,11 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { int64 column_remainder = k() % tile_cols(); int64 column_limit = k() - column_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols(), is_first_column); - }); + ksl_.ForReturnVoid("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); if (column_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, @@ -341,19 +341,20 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( int64 columns, bool is_first_column) { int64 row_limit = m() - (m() % tile_rows()); - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows(), [&](llvm::Value* row) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = - is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) - : vsl_.GetZeroVector()) - : vsl_.LoadVector(result_, row); - for (int i = 0; i < columns; i++) { - accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); - } - vsl_.StoreVector(accumulator, result_, row); - }); + ksl_.ForReturnVoid( + "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows(), [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? 
vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); } void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( @@ -372,7 +373,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // initialized. // } - ksl_.For( + ksl_.ForReturnVoid( "dot.inner.epilg.outer", /*start=*/current_tile_col, /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), /*step=*/1, /*peel_first_iteration=*/false, @@ -382,7 +383,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( ir_builder_->CreateMul(col, ir_builder_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For( + ksl_.ForReturnVoid( "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( @@ -390,7 +391,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( is_first_scalar_col, ir_builder_->getInt1(is_first_tiled_column)); - ksl_.If( + ksl_.IfReturnVoid( setting_result_first_time, /*true_block_generator=*/ [&]() { @@ -571,9 +572,10 @@ void RowMajorMatrixVectorProductEmitter::Emit() { int64 row_remainder = m() % tile_rows(); int64 row_limit = m() - row_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); + ksl_.ForReturnVoid( + "dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); @@ -585,17 +587,17 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( std::vector* 
vector_accumulators) { int64 column_limit = k() - (k() % tile_cols()); - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols(), [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set( - vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); + ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set(vsl_.Add( + old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( @@ -612,14 +614,15 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( ir_builder_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), - /*step=*/1, [&](llvm::Value* scalar_col) { - llvm::Value* product = - vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), - vsl_.LoadScalar(rhs_, scalar_col)); - llvm::Value* old_value = (*scalar_accumulators)[r].Get(); - (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); - }); + ksl_.ForReturnVoid( + "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); } } @@ -740,7 +743,7 @@ class 
MatrixMatrixBlockPanelEmitter { private: // The HandleResiduesOnX helpers split the iteration space for dimension X // into a multiple of the tile size on dimension X and an epilogue. These - // helpers ultimately call into `EmitTiledReductionLoop` for emitting the + // helpers ultimately call into `EmitTiledGemm` for emitting the // tiled GEMM kernel. void HandleResiduesOnN(); @@ -750,15 +753,13 @@ class MatrixMatrixBlockPanelEmitter { llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end); - // This emits the inner reduction loop. This inner reduction loop multiplies - // a tile from the LHS of size [tile_size_m,tile_size_k] and a tile from the - // RHS of size [`tile_size_k`, vls->vector_width()] to update a tile of size - // [`tile_size_m`, vls->vector_width()] in the result. - void EmitTiledReductionLoop(VectorSupportLibrary* vsl, int64 tile_size_k, - llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end, - int64 tile_size_m, llvm::Value* m_start, - llvm::Value* m_end); + // This emits a tiled GEMM kernel. For a detailed description see the comment + // on the implementation. 
+ void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); } @@ -819,7 +820,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { if (n_start != dims().n()) { VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp"); - ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { llvm::Value* n_i_next = ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); HandleResiduesOnK(&vsl, n_i, n_i_next); @@ -848,16 +849,24 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { const int64 m_end = dims().m() - dims().m() % tile_size_m(); - EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end, - tile_size_m(), GetInt64(0), GetInt64(m_end)); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), + GetInt64(0), GetInt64(m_end)); if (m_end != dims().m()) { - EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end, - dims().m() - m_end, GetInt64(m_end), - GetInt64(dims().m())); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, + dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m())); } } +// The loop structure is: +// +// Iterate over dimension M as m: +// Iterate over dimension N as n: +// Iterate over dimension K as k: +// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) +// +// I.e. a just a tiled version of a "naive" GEMM. 
+// // The tiling scheme is as follows: // // Let the LHS be: @@ -919,41 +928,48 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( // +-------------------+-------------------+-------------------+--------- // | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... // +-------------------+-------------------+-------------------+--------- -void MatrixMatrixBlockPanelEmitter::EmitTiledReductionLoop( +void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { - ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { - MemoryTile result_memory_tile(vsl, ir_builder_, /*matrix=*/result_, - /*matrix_size_along_minor_dim=*/dims().n(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/dims().k(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - - ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { - MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, dims().n(), k_i, - tile_size_k); - std::vector> lhs_tile = - lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); - ksl_.For( - "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { - std::vector rhs_tile = rhs_memory_tile.LoadTile(n_i); - std::vector result_tile = - result_memory_tile.LoadTile(n_i); - for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { - for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { - result_tile[r_m_i] = - vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], - result_tile[r_m_i]); - } - } - result_memory_tile.StoreTile(result_tile, n_i); - }); - }); - }); + ksl_.ForReturnVoid( + "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { + MemoryTile result_memory_tile( + vsl, ir_builder_, 
/*matrix=*/result_, + /*matrix_size_along_minor_dim=*/dims().n(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/dims().k(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + ksl_.ForReturnVoid( + "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { + TileVariable result_tile_var(vsl, + result_memory_tile.LoadTile(n_i)); + ksl_.ForReturnVoid( + "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { + MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, + dims().n(), k_i, tile_size_k); + std::vector> lhs_tile = + lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); + std::vector rhs_tile = + rhs_memory_tile.LoadTile(n_i); + std::vector result_tile = + result_tile_var.Get(); + for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { + for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { + result_tile[r_m_i] = + vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], + result_tile[r_m_i]); + } + } + result_tile_var.Set(result_tile); + }); + + result_memory_tile.StoreTile(result_tile_var.Get(), n_i); + }); + }); } } // namespace @@ -1285,8 +1301,11 @@ Status DotOpEmitter::Emit() { // from messing up the vectorization. std::unique_ptr reduction_loop = loop_nest.AddLoop( 0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction", - /*prevent_unrolling=*/lhs_reduction_along_minor_dimension && - rhs_reduction_along_minor_dimension); + /*unroll_mode=*/ + (lhs_reduction_along_minor_dimension && + rhs_reduction_along_minor_dimension) + ? xla::llvm_ir::UnrollMode::kNoUnroll + : xla::llvm_ir::UnrollMode::kDefaultUnroll); // The final entry in the rhs and lhs indexes is the indvar of the // reduction loop. 
@@ -1608,8 +1627,8 @@ bool PotentiallyImplementedAsEigenDot( const Shape& lhs_shape = hlo.operand(0)->shape(); const Shape& rhs_shape = hlo.operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { + if (ShapeUtil::IsZeroElementArray(lhs_shape) || + ShapeUtil::IsZeroElementArray(rhs_shape)) { return false; } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 2effb7fc360a47bf780fcbf9b6c9a096cb1cf41e..ed2a18976a0f1a88e7bb4632d3a63167d5c146ad 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -144,8 +144,12 @@ class DotOpEmitter { } std::tuple GetGemmTileSize() const { + // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz + // + // TODO(b/80093688): Tune for other architectures and centralize this + // information in one place. const std::tuple kDefaultTileSize = - std::tuple(3, 5, 1); + std::tuple(11, 9, 1); return options::LlvmIrGemmTileSize(hlo_module_config_) .value_or(kDefaultTileSize); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index b560b7531c0d24e6f670e61a15dce295d9fa2a49..1a8bedfe6afb4f096ddd4703c312b84d521a7ba5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -64,8 +64,8 @@ bool PotentiallyImplementedAsEigenConvolution( return false; } - if (ShapeUtil::HasZeroElements(input_shape) || - ShapeUtil::HasZeroElements(kernel_shape)) { + if (ShapeUtil::IsZeroElementArray(input_shape) || + ShapeUtil::IsZeroElementArray(kernel_shape)) { return false; } // Make sure input and kernel has the same data type. 
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 59223fddac2f5f7e2e85de4d37e4b6c5760ae697..758b8c62b4800215caae82208454ac971807f6eb 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -226,10 +226,13 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); return EmitMemcpy(*(copy->operand(0)), *copy); - } else { - // Use the elemental emitter for non-tuple shapes. + } else if (ShapeUtil::IsArray(copy->shape())) { + // Use the elemental emitter for array shapes. return DefaultAction(copy); } + return Unimplemented( + "unsupported operand type %s for copy instruction", + PrimitiveType_Name(copy->shape().element_type()).c_str()); } // Calculate the alignment of a buffer allocated for a given primitive type. @@ -1873,7 +1876,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice)); - if (ShapeUtil::HasZeroElements(slice->shape())) { + if (ShapeUtil::IsZeroElementArray(slice->shape())) { return Status::OK(); } @@ -2528,6 +2531,13 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { return Status::OK(); } +Status IrEmitter::HandleGenerateToken(HloInstruction* gen_token) { + TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0); + // No code to generate, but we need to emit an address for book-keeping. + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token)); + return Status::OK(); +} + Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. 
The IR value holds the address of the buffer holding @@ -2809,7 +2819,10 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { // For the root node, we write directly to the output buffer of the // function. llvm::Argument* retval = compute_function_->result_arg(); - if (!ShapeUtil::IsNil(target_shape)) { + if ((ShapeUtil::IsArray(target_shape) && + !ShapeUtil::IsZeroElementArray(target_shape)) || + (ShapeUtil::IsTuple(target_shape) && + !ShapeUtil::IsEmptyTuple(target_shape))) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 32c536e18fee86cc60067ba3b25ab1eb0e4233df..e1815c1db7a14dfc90ff646c0fd1e439ffffb2e8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -150,6 +150,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; + Status HandleGenerateToken(HloInstruction* gen_token) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc index 92da5f71c23d5e1450b39ea8b7bb8345f6fabb3b..f8c8dd5e93d53db8d87be0208b5cf4daac3464f1 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML) #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "third_party/intel_mkl_ml/include/mkl_cblas.h" #include "third_party/intel_mkl_ml/include/mkl_service.h" diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index cd1165e23812861ba9951546b7dd744529232196..c444d151858d3a152a01b99657ffae89ebc6b487 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -427,5 +427,27 @@ llvm::Value* LlvmVariable::Get() const { void LlvmVariable::Set(llvm::Value* new_value) { ir_builder_->CreateStore(new_value, alloca_); } + +TileVariable::TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value) { + for (llvm::Value* initial_vector_value : initial_value) { + storage_.emplace_back(vector_support, initial_vector_value); + } +} + +std::vector TileVariable::Get() const { + std::vector result; + c_transform(storage_, std::back_inserter(result), + [&](VectorVariable vect_var) { return vect_var.Get(); }); + return result; +} + +void TileVariable::Set(tensorflow::gtl::ArraySlice value) { + CHECK_EQ(value.size(), storage_.size()); + for (int64 i = 0, e = value.size(); i < e; i++) { + storage_[i].Set(value[i]); + } +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index edcaec584997b17dce30b8c46fda4abc78441064..49c2a4e2f4bae9e1672b7d2fe891301bce08bd4b 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -317,6 +318,21 @@ class ScalarVariable : public LlvmVariable { Set(initial_value); } }; + +// This wraps a set of alloca-backed stack variables that can, as a whole, store +// a tile. A "tile" is a sequence of vectors that is typically used as a 2D +// grid of scalar values (e.g. for tiled GEMMs). +class TileVariable { + public: + TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value); + + std::vector Get() const; + void Set(tensorflow::gtl::ArraySlice value); + + private: + std::vector storage_; +}; } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 64678d9d7450974f68817f92526519697a83683c..ee2b455730f8f520db6652f0352f8a96291cac73 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -243,6 +243,8 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; + virtual Status HandleGenerateToken(HloInstructionPtr token) = 0; + // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". 
virtual Status FinishVisit(HloInstructionPtr root) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 240faebe62f5cee4f61b3c36b5e8f653cfd6db8e..6934e00a4b665e9e6a4302e0c0a8ce1d5bb94373 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -188,6 +188,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } + Status HandleGenerateToken(HloInstructionPtr token) override { + return DefaultAction(token); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 9a8bab353ef6b1e0b05b250d35296bc3cef8bc37..93fea7ead7a86bb34c449668fd88a58145681eb1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -456,17 +456,15 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { - // (x == x) && abs(x) != inf + // abs(x) o!= inf, this works because the comparison returns false if + // either operand is NaN. 
auto type = operand_value->getType(); - auto equal_self = - ir_builder_->CreateFCmpOEQ(operand_value, operand_value); auto abs_value = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_); auto infinity = llvm::ConstantFP::getInfinity(type); auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); - auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite); return ir_builder_->CreateZExt( - result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: return ir_builder_->CreateFNeg(operand_value); diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 087bd1432945abfd860fcb8b1e92dd419598e025..dc1f26ea65cc707d4f0522af2aa3ec40621632f1 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -131,12 +130,6 @@ class Executable { const HloModuleConfig& module_config() const { return hlo_module_->config(); } - // Returns the versioned computation handle of the computation computed by - // this executable. - const VersionedComputationHandle& entry_computation_handle() const { - return hlo_module_->entry_computation_handle(); - } - // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. 
const Shape& host_result_shape() const { diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 2d3e4b1fcdf6675955714cab262a8b2ca8ff4297..7cd2c9c136acac46e8e6c548c9e58b9bc8e6e0d2 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -300,7 +300,7 @@ static StatusOr PermuteGatherAndWindowDims( StatusOr GatherExpander::ExpandGather( HloInstruction* gather_instr) { - CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape())); + CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape())); HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); @@ -369,7 +369,7 @@ StatusOr GatherExpander::Run(HloModule* module) { return inst->opcode() == HloOpcode::kGather && // Avoid expanding gather ops that produce zero sized tensors, // instead punt these to ZeroSizedHloElimination. - !ShapeUtil::HasZeroElements(inst->shape()); + !ShapeUtil::IsZeroElementArray(inst->shape()); }; std::vector gather_instrs; diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 5ee67ccb4ae147683c7b41941670c6fc413a0d09..d9f62c21c4ef932bb61f2f9e0f7a318366ce94f0 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -74,7 +74,7 @@ GenericTransferManager::TransferLiteralFromDevice( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { - if (!ShapeUtil::IsTuple(subshape)) { + if (ShapeUtil::IsArray(subshape)) { TF_RETURN_IF_ERROR(TransferBufferFromDevice( executor, /*source=*/device_buffer.buffer(index), diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 
6bd9d4c31df5b76820abcb711f910b7c468c057d..541a5275a384ebfd900be086216c6d0c6958cd88 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -164,6 +164,7 @@ cc_library( "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", @@ -236,6 +237,19 @@ cc_library( ], ) +cc_library( + name = "hlo_execution_profiler", + srcs = ["hlo_execution_profiler.cc"], + hdrs = ["hlo_execution_profiler.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_execution_profile", + "//tensorflow/compiler/xla/service:pool", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + cc_library( name = "gpu_executable", srcs = [ @@ -277,6 +291,7 @@ cc_library( ":backend_configs", ":buffer_allocations", ":cudnn_convolution_runner", + ":hlo_execution_profiler", ":infeed_manager", ":ir_emission_utils", ":partition_assignment", @@ -422,6 +437,34 @@ tf_cc_test( ], ) +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:multi_output_fusion", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "multi_output_fusion_test", + srcs = ["multi_output_fusion_test.cc"], + deps = [ + ":multi_output_fusion", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + 
"//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + cc_library( name = "gpu_copy_insertion", srcs = ["gpu_copy_insertion.cc"], @@ -522,6 +565,7 @@ cc_library( ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", + ":multi_output_fusion", ":pad_insertion", ":partition_assignment", ":stream_assignment", diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index e0c73aa73acb7f3313eb54fb07390cb76590433e..f9dccd287d955502858f6c24ccd4de80256fc148 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -42,8 +42,8 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { } // CuDNN does not accept zero-element arguments - if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) || - ShapeUtil::HasZeroElements(conv->operand(1)->shape())) { + if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) { return false; } diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index e5e2a0478a0659986ddec8d6785827b14b9efb56..b812dd7d3fbb25f279e87f79b647e299f29073ea 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -53,11 +53,17 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrAppend; +namespace { // Returns whether operand is a floating-point literal with the given value. 
bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { - return operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAllFloat(value); + if (operand->opcode() == HloOpcode::kConstant && + operand->literal().IsAllFloat(value)) { + return true; + } + return operand->opcode() == HloOpcode::kBroadcast && + IsFPLiteralWithValue(operand->operand(0), value); } +} // namespace GpuElementalIrEmitter::GpuElementalIrEmitter( const HloModuleConfig& hlo_module_config, llvm::Module* module, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index b85721980715e2ce2cd7a689ab12a6cea55ba3f1..9d66648a402fb82c35e0bf3ea1179f7995ed7c76 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" @@ -159,13 +160,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) { pass.AddPass(); } - // TODO(kramerb): Remove use_fusion once instruction fusion can create - // multi-output fusions from the unfused expander output. pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/true); + /*rewrite_grad_op=*/true); // Rewrite gather ops into smaller ones. 
pass.AddPass(); @@ -261,6 +259,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); + fusion.AddPass(); + fusion.AddPass(/*is_layout_sensitive=*/true, + /*only_fusion_computations=*/true); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 25d8f720ea4791a4c94efcad6909cd0c113fbe70..f20a828bc1a31ad15298a1d77cd79599aa12faf4 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -41,77 +41,6 @@ namespace { using tensorflow::tracing::ScopedAnnotation; -// A helper class for profiling HLO in the course of GPU program execution. -// All of the profiling is guarded internally, to avoid the caller needing to -// have lots of conditionals sprinkled around. -class HloExecutionProfiler { - public: - // If profiling is enabled, start an execution timer running. 
- explicit HloExecutionProfiler( - bool do_profile, HloExecutionProfile* profile, se::Stream* stream, - const std::vector::SmartPtr>& sub_streams, - const HloComputation* computation) - : do_profile_(do_profile), - profile_(profile), - stream_(stream), - sub_streams_(sub_streams), - computation_(computation) { - if (do_profile_) { - clock_rate_ghz_ = - stream->parent()->GetDeviceDescription().clock_rate_ghz(); - execution_timer_.reset(new se::Timer(stream->parent())); - per_op_timer_.reset(new se::Timer(stream->parent())); - stream->InitTimer(execution_timer_.get()) - .ThenStartTimer(execution_timer_.get()); - stream->InitTimer(per_op_timer_.get()); - } - } - - // If profiling is enabled, sets the total cycle count on the profile from the - // execution timer. - void FinishExecution() { - CHECK(!finished_execution_) << "Call FinishExecution only once!"; - finished_execution_ = true; - if (do_profile_) { - stream_->ThenWaitFor(&sub_streams_); - stream_->ThenStopTimer(execution_timer_.get()); - stream_->BlockHostUntilDone().IgnoreError(); - profile_->set_total_cycles_executed( - *computation_, execution_timer_->Nanoseconds() * clock_rate_ghz_); - } - } - - // If profiling is enabled, starts the per-operation timer. - void StartOperation() { - if (do_profile_) { - stream_->ThenStartTimer(per_op_timer_.get()); - } - } - - // If profiling is enabled, stops the per-operation timer and records the time - // that the hlo_instruction took to execute in the profile. 
- void FinishOperation(const HloInstruction* hlo_instruction) { - if (do_profile_) { - stream_->ThenWaitFor(&sub_streams_); - stream_->ThenStopTimer(per_op_timer_.get()); - stream_->BlockHostUntilDone().IgnoreError(); - profile_->SetCyclesTakenBy( - hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_); - } - } - - private: - const bool do_profile_; - double clock_rate_ghz_; - HloExecutionProfile* profile_; - se::Stream* stream_; - const std::vector::SmartPtr>& sub_streams_; - const HloComputation* computation_; - std::unique_ptr execution_timer_; - std::unique_ptr per_op_timer_; - bool finished_execution_ = false; -}; - } // namespace // Implementation note: HLO profiling is always enabled for GPU executables, diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..daddd3738e4bb54f3695a96f6f9ffb9accabe97c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -0,0 +1,82 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +HloExecutionProfiler::HloExecutionProfiler( + bool do_profile, HloExecutionProfile* profile, se::Stream* stream, + const std::vector::SmartPtr>& sub_streams, + const HloComputation* computation) + : do_profile_(do_profile), + profile_(profile), + stream_(stream), + sub_streams_(sub_streams), + computation_(computation) { + if (do_profile_) { + clock_rate_ghz_ = stream->parent()->GetDeviceDescription().clock_rate_ghz(); + execution_timer_.reset(new se::Timer(stream->parent())); + per_op_timer_.reset(new se::Timer(stream->parent())); + stream->InitTimer(execution_timer_.get()) + .ThenStartTimer(execution_timer_.get()); + stream->InitTimer(per_op_timer_.get()); + } +} + +void HloExecutionProfiler::FinishExecution() { + CHECK(!finished_execution_) << "Call FinishExecution only once!"; + finished_execution_ = true; + if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); + stream_->ThenStopTimer(execution_timer_.get()); + stream_->BlockHostUntilDone().IgnoreError(); + profile_->set_total_cycles_executed( + *computation_, + static_cast(execution_timer_->Nanoseconds() * clock_rate_ghz_)); + } +} + +void HloExecutionProfiler::StartOperation() { + if (do_profile_) { + stream_->ThenStartTimer(per_op_timer_.get()); + } +} + +void HloExecutionProfiler::FinishOperation( + const HloInstruction* hlo_instruction) { + if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); + 
stream_->ThenStopTimer(per_op_timer_.get()); + stream_->BlockHostUntilDone().IgnoreError(); + profile_->SetCyclesTakenBy( + hlo_instruction, + static_cast(per_op_timer_->Nanoseconds() * clock_rate_ghz_)); + } +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..c9b882ff805c45a57f15df4fe79dc34100c0ceff --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A helper class for profiling HLO in the course of GPU program execution. +// All of the profiling is guarded internally, to avoid the caller needing to +// have lots of conditionals sprinkled around. 
+class HloExecutionProfiler { + public: + // If profiling is enabled, start an execution timer running. + explicit HloExecutionProfiler( + bool do_profile, HloExecutionProfile* profile, se::Stream* stream, + const std::vector::SmartPtr>& sub_streams, + const HloComputation* computation); + + // If profiling is enabled, sets the total cycle count on the profile from the + // execution timer. + void FinishExecution(); + + // If profiling is enabled, starts the per-operation timer. + void StartOperation(); + + // If profiling is enabled, stops the per-operation timer and records the time + // that the hlo_instruction took to execute in the profile. + void FinishOperation(const HloInstruction* hlo_instruction); + + private: + const bool do_profile_; + double clock_rate_ghz_; + HloExecutionProfile* profile_; + se::Stream* stream_; + const std::vector::SmartPtr>& sub_streams_; + const HloComputation* computation_; + std::unique_ptr execution_timer_; + std::unique_ptr per_op_timer_; + bool finished_execution_ = false; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index f766f968826d960a8e86308f2395301aaa09f1ae..375709150e08996ea6a40f5e9e66a8f8d9287008 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -199,7 +199,7 @@ StatusOr> HloSchedule::Build( // concurrency by optimizing for minimal memory usage. 
TF_ASSIGN_OR_RETURN( schedule->thunk_launch_order_, - CreateMemoryMinimizingSequence( + ScheduleOneComputation( *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index e230d538cc2df826778e8d13eaaaf31ec81c57f0..45f0a1c645b2875cf90d2c11cfb66c3dd855d097 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -47,8 +47,7 @@ class HloScheduleTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", VersionedComputationHandle(), - config); + return MakeUnique("test_module", config); } HloVec RemoveHlo(const HloVec& input, diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 061210352cf12e6802d066d311fd2cb481673f15..e303999c63ff699487bc2362850459ab691f6bc8 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -202,7 +202,7 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, << " of " << hlo.ToString(); llvm_ir::IrArray ir_array(base_ptr, ShapeUtil::GetSubshape(hlo.shape(), shape_index)); - alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); + alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index); // The GPU backend emits one kernel per top-level HLO, and LLVM views // execution of one kernel as the "whole program" executed on the GPU. 
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 36a1b82a26d84fb557c894f0bf122aef064b052e..6c4519185b34989eb53c884ba214d69b824b113c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -77,15 +77,14 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha - if (consumer->operand_count() == 2 && - (producer->opcode() == HloOpcode::kDot || - (producer->opcode() == HloOpcode::kFusion && - producer->fused_expression_root()->opcode() == HloOpcode::kDot))) { + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { int64 other_operand_index = 1 - operand_index; - const HloInstruction* alpha = consumer->operand(other_operand_index); HloInstruction* op1 = nullptr; HloInstruction* op2 = nullptr; - if (consumer->opcode() == HloOpcode::kFusion && + if (consumer->operand_count() == 1 && + consumer->opcode() == HloOpcode::kFusion && consumer->fusion_kind() == HloInstruction::FusionKind::kLoop && Match(consumer->fused_expression_root(), match::Op() @@ -103,10 +102,12 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, op2->opcode() != HloOpcode::kBroadcast) { return false; } - if (IsIEEEFloatingPointScalarConstant(alpha)) { + if (IsIEEEFloatingPointScalarConstant(op2->operand(0))) { return true; } - } else if (consumer->opcode() == HloOpcode::kMultiply) { + } else if (consumer->operand_count() == 2 && + consumer->opcode() == HloOpcode::kMultiply) { + const HloInstruction* alpha = consumer->operand(other_operand_index); // Fuse if 'alpha' is a broadcast of a scalar constant. 
if (alpha->opcode() == HloOpcode::kBroadcast && alpha->dimensions().empty() && @@ -173,6 +174,14 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Fuse scalar constants into loop fusion nodes, this reduces the number of + // parameters and makes matching scalar broadcasts easier. + if (ShapeUtil::IsEffectiveScalar(producer->shape()) && + consumer->opcode() == HloOpcode::kFusion && + producer->opcode() == HloOpcode::kConstant) { + return true; + } + return IsFusile(*producer) && IsFusile(*consumer) && InstructionFusion::ShouldFuse(consumer, operand_index); } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 426b1d235c3135ff61671481044beed518e2db00..1963d9eef72d41fa0a275bea98f959671fa7e737 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -168,7 +168,7 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Fusion()); EXPECT_THAT(root->fused_expression_root(), - op::Reduce(op::Broadcast(op::Parameter()), op::Parameter())); + op::Reduce(op::Broadcast(op::Constant()), op::Constant())); } TEST_F(InstructionFusionTest, BitcastIntoAdd) { @@ -255,7 +255,7 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { EXPECT_THAT( root->fused_expression_root(), op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), - op::Broadcast(op::Parameter()))); + op::Broadcast(op::Constant()))); } // Compute sum(1/p0), where p0 has type f32, twice. 
Check that the division is @@ -339,7 +339,7 @@ TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); EXPECT_THAT(root->fused_expression_root(), op::Multiply(op::Multiply(op::Parameter(), op::Parameter()), - op::Broadcast(op::Parameter()))); + op::Broadcast(op::Constant()))); } // Counts the HLO ops with a given op code in the specified module. @@ -581,5 +581,30 @@ TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { << module->ToString(); } +TEST_F(InstructionFusionTest, FuseScalarConstant) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY FuseScalarConstant { + p0 = f32[] parameter(0) + c0 = f32[] constant(1) + add1 = f32[] add(p0, c0) + b0 = f32[2]{0} broadcast(add1), dimensions={} + c1 = f32[2]{0} constant({1, 2}) + ROOT add2 = f32[2]{0} add(b0, c1) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Add(op::Broadcast(op::Add(op::Parameter(), op::Constant())), + op::Parameter())); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 67890bfed1136796c83c7ef6912ffc1ab1b7e332..388aa35d7dceeef92dbdb6c8a3bb7fb3796a0b61 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -56,8 +56,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, return type_is_allowed && IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && IsRank2WithNoPadding(output_shape) && - !ShapeUtil::HasZeroElements(lhs_shape) && - !ShapeUtil::HasZeroElements(rhs_shape); + !ShapeUtil::IsZeroElementArray(lhs_shape) && + 
!ShapeUtil::IsZeroElementArray(rhs_shape); } bool DotImplementedAsGemm(const HloInstruction& dot) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 547af33e9a98c03e1429366172f9a401e385a9d1..7b7dd673a5c35e586105f1a6253c72c3aa0b0151 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -610,7 +610,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { - if (ShapeUtil::HasZeroElements(convolution->shape())) { + if (ShapeUtil::IsZeroElementArray(convolution->shape())) { // Emit no code for an empty output. return Status::OK(); } @@ -620,7 +620,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { } Status IrEmitter::HandleFft(HloInstruction* fft) { - if (ShapeUtil::HasZeroElements(fft->shape())) { + if (ShapeUtil::IsZeroElementArray(fft->shape())) { // Emit no code for an empty output. return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index ed005f6afcc6bcd8c56a76301be67bb77ef91fb8..ccbd99a0420ae8d5183fa112468b3f7cc678503e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -59,6 +59,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" @@ -1391,6 +1392,30 @@ Status IrEmitterUnnested::EmitColumnReduction( .EmitLoop(IrName(reduce)); } +static std::pair ComputeTilingSchemeForReduction( + int64 depth, int64 width, int64 kWarpSize) { + constexpr int64 kTargetNumElementsPerThread = 64; + int64 x_tile_size = kTargetNumElementsPerThread; + int64 z_tile_size = 1; + + // Only tile along the x dimension with tile size kTargetNumElementsPerThread + // if doing so doesn't require a slow version of loop with bound check on each + // dimension. A more sophisticated heuristic is to enable tiling along the + // x dimension with tile size kTargetNumElementsPerThread when either width is + // a multiple of (kWarpSize * kTargetNumElementsPerThread) or width is big + // enough so that only a small fraction of the threads execute the slow + // version of loop with bound check. + if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) { + x_tile_size = 8; + z_tile_size = 8; + while (depth % z_tile_size != 0) { + z_tile_size -= 1; + } + } + + return std::pair(x_tile_size, z_tile_size); +} + Status IrEmitterUnnested::EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, @@ -1402,7 +1427,7 @@ Status IrEmitterUnnested::EmitRowReduction( std::pair> extra_output_gens) { // A naive algorithm is: - // 1. Divide the input tensor into tiles of size 1x1xK. + // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. // 2. Partially reduces each tile to a scalar using one thread. // 3.
Accumulates that scalar to the output vector using atomic operations. // @@ -1413,15 +1438,15 @@ Status IrEmitterUnnested::EmitRowReduction( // int y = linear_index / width_in_tiles % height; // int z = linear_index / (height * width_in_tiles); // float partial_result = 0; - // for (element_id_in_tile : range(kTileSize)) { - // int x = x_in_tiles * kTileSize + element_id_in_tile; + // for (element_id_in_tile : range(x_tile_size)) { + // int x = x_in_tiles * x_tile_size + element_id_in_tile; // if (x < width) // partial_result = reducer(partial_result, input[z][y][z]); // } // AtomicReducer(&output[y], partial_result); // } // - // Three optimizations are performed. + // Four optimizations are performed. // // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead @@ -1448,29 +1473,44 @@ Status IrEmitterUnnested::EmitRowReduction( // element_id_in_tile, which makes the code more friendly to optimizations // such as LICM. // + // 4. When the width is too small and x_tile_size is less than the target + // number of elements per thread, use a small factor of depth as + // z_tile_size to increase the number of elements calculated by each + // partial sum. This can reduce the needed number of dynamic shfl_down and + // atomic operations.
+ // // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; // linear_index < depth * height * width_in_tiles; // linear_index += blockDim.x * gridDim.x) { // int x_in_tiles = linear_index % width_in_tiles; // int y = linear_index / width_in_tiles % height; - // int z = linear_index / (height * width_in_tiles); + // int z_in_tiles = linear_index / (height * width_in_tiles); // int warp_id = x_in_tiles / warpSize; // int lane_id = x_in_tiles % warpSize; // float partial_result = 0; // int x = warp_id * kTileSize * warpSize + lane_id; - // if (width % (kTileSize * warpSize) == 0 || - // x + (kTileSize - 1) * warpSize < width) { - // // The entire tile is in bounds. - // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; - // ++element_id_in_tile, x += warpSize) { - // partial_result = Reducer(partial_result, input[z][y][x]); + // if (width % (x_tile_size * warpSize) == 0 || + // x + (x_tile_size - 1) * warpSize < width) { + // // The entire x_tile is in bounds. + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // for (int element_id_in_x_tile = 0; + // element_id_in_x_tile < x_tile_size; + // ++element_id_in_x_tile, x += warpSize) { + // partial_result = Reducer(partial_result, input[z][y][x]); + // } // } // } else { // // The tile is partially in bounds. 
- // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; - // ++element_id_in_tile, x += warpSize) { - // if (x < width) - // partial_result = Reducer(partial_result, input[z][y][x]); + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // for (int element_id_in_x_tile = 0; element_id_in_x_tile < + // x_tile_size; ++element_id_in_x_tile, x += warpSize) { + // if (x < width) + // partial_result = Reducer(partial_result, input[z][y][x]); + // } // } // } // for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2) @@ -1481,17 +1521,20 @@ Status IrEmitterUnnested::EmitRowReduction( // AtomicReducer(&output[y], partial_result); // } // - // Choose 8 as the tile size, which matches Eigen's RowReduceKernel. - constexpr int64 kTileSize = 8; + + int64 x_tile_size; + int64 z_tile_size; + std::tie(x_tile_size, z_tile_size) = + ComputeTilingSchemeForReduction(depth, width, kWarpSize); + // Round the width in tiles up to the nearest multiple of kWarpSize, so that // the use of shfl_down is valid. const int64 width_in_tiles = - RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize); + RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) { + // Emit the loop body that reduces one z-x-tile. const int num_reduces = reducers.size(); - // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( input_shape.element_type(), ir_emitter_context_->llvm_module()); std::vector partial_reduction_result_addresses; @@ -1506,9 +1549,7 @@ Status IrEmitterUnnested::EmitRowReduction( partial_reduction_result_address); } - // Emit an inner for-loop that partially reduces the elements in the given - // tile.
- llvm::Value* z = tile_index[0]; + llvm::Value* z_tile = tile_index[0]; llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; llvm::Value* warp_id = ir_builder_.CreateUDiv( @@ -1516,107 +1557,131 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* lane_id = ir_builder_.CreateURem( x_tile, ir_builder_.getInt64(kWarpSize), "lane_id"); - // The x-location of the last element in this tile. - // last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize); + // The x-location of the last element in this z-x-tile. + // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); llvm::Value* last_x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize - 1), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); - - auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - // x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize); - llvm::Value* x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - tile_element_loop->GetIndVarValue(), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); - - // Unless we know the tile is entirely in bounds, we have to emit a - // x-in-bounds check before reading from the input. - if (!tile_in_bounds) { - llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)), - "x_in_bounds", &ir_builder_); - - // Points ir_builder_ to the then-block. 
- llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &ir_builder_); - } + lane_id, ir_builder_.CreateNSWMul( + ir_builder_.getInt64(kWarpSize), + ir_builder_.CreateNSWAdd( + ir_builder_.getInt64(x_tile_size - 1), + ir_builder_.CreateNSWMul( + warp_id, ir_builder_.getInt64(x_tile_size))))); + + KernelSupportLibrary ksl( + &ir_builder_, + /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, + /*prevent_vectorization=*/false); + + // Emit a for-loop that partially reduces the elements in the given + // z-x-tile. + auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, + int64 x_tile_loop_bound) -> Status { + auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { + llvm::Value* z = ir_builder_.CreateNSWAdd( + z_indvar, ir_builder_.CreateNSWMul( + ir_builder_.getInt64(z_tile_size), z_tile)); + + TF_RETURN_IF_ERROR(ksl.For( + "x_tile", + /*start=*/0, /*end=*/x_tile_loop_bound, /*step=*/1, + [&](llvm::Value* x_indvar) -> Status { + // x = lane_id + + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); + llvm::Value* x = ir_builder_.CreateNSWAdd( + lane_id, + ir_builder_.CreateNSWMul( + ir_builder_.getInt64(kWarpSize), + ir_builder_.CreateNSWAdd( + x_indvar, + ir_builder_.CreateNSWMul( + warp_id, ir_builder_.getInt64(x_tile_size))))); + + // Unless we know the x-tile is entirely in bounds, we have to + // emit a x-in-bounds check before reading from the input. + if (!x_tile_in_bounds) { + llvm_ir::LlvmIfData if_x_in_bounds_data = + llvm_ir::EmitIfThenElse(ir_builder_.CreateICmpULT( + x, ir_builder_.getInt64(width)), + "x_in_bounds", &ir_builder_); + // Points ir_builder_ to the then-block. + llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, + &ir_builder_); + } + + // Emit code that reads the input element and accumulates it + // to the partial reduction result. 
+ llvm::Value* input_address = + ir_builder_.CreateAlloca(element_ir_type); + { + // {z,y,x} is an index to input_3d_tensor_shape + // [depth,height,width]. We need to convert that to an index + // to input_shape (the shape of the operand of "reduce"). + // This conversion is composed of a transposition from + // input_shape to normalized_input_shape and a reshape from + // normalized_input_shape to input_3d_tensor_shape. + const Shape normalized_input_shape = ShapeUtil:: + MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = + LayoutUtil::MinorToMajor(input_shape); + const std::vector transpose_dimension_mapping( + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); + const Shape input_3d_tensor_shape = + ShapeUtil::MakeShapeWithDescendingLayout( + input_shape.element_type(), {depth, height, width}); + const llvm_ir::IrArray::Index input_3d_tensor_index( + {z, y, x}, input_3d_tensor_shape, &ir_builder_); + const llvm_ir::IrArray::Index input_index = + input_3d_tensor_index + .SourceIndexOfReshape(input_3d_tensor_shape, + normalized_input_shape, + &ir_builder_) + .SourceIndexOfTranspose( + normalized_input_shape, input_shape, + transpose_dimension_mapping, &ir_builder_); + + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); + } + })); + return Status::OK(); + }; - // Emit code that reads the input element and accumulates it to the - // partial reduction result. - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - { - // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width]. 
We - // need to convert that to an index to input_shape (the shape of the - // operand of "reduce"). This conversion is composed of a transposition - // from input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_3d_tensor_shape. - const Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), - {depth, height, width}); - const llvm_ir::IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = - input_3d_tensor_index - .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, &ir_builder_) - .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, - &ir_builder_); - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); - } - return EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens); - } + return ksl.For("z_tile", + /*start=*/0, /*end=*/z_tile_size, /*step=*/1, + emit_z_tile_element_loop); }; llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0), + ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0), ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width))); - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - 
llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - // After the if-then-else statement on tile_in_bounds, emit calls to - // shfl_down that accumulate the partial reduction results of all threads - // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + TF_RETURN_IF_ERROR( + ksl.If(tile_in_bounds, + /*true_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true, + x_tile_size); + }, + /*false_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop( + /*x_tile_in_bounds=*/false, + CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize)); + })); + + // After accumulating the elements of the z_x_tile, emit calls to + // shfl_down that accumulate the partial reduction results of all + // threads in a warp. int bit_width = llvm_ir::GetSizeInBits(element_ir_type); // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. 
@@ -1666,16 +1731,24 @@ Status IrEmitterUnnested::EmitRowReduction( reduce_output_shapes[i]), &ir_builder_), &ir_builder_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_reduction_result_addresses[i])); + if (x_tile_size * z_tile_size < depth * width) { + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i])); + } else { + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {output_address, partial_reduction_result_addresses[i]}, + output_address)); + } } return Status::OK(); }; // Emit a parallel loop that iterates through every input tiles. Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {depth, height, width_in_tiles}, - {2, 1, 0}); + reduce->shape().element_type(), + {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); @@ -2132,6 +2205,10 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { return Status::OK(); } +Status IrEmitterUnnested::HandleGenerateToken(HloInstruction* gen_token) { + return Status::OK(); +} + Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { thunk_sequence_->emplace_back(BuildInfeedThunk(infeed)); return Status::OK(); @@ -2440,7 +2517,9 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( if (alpha->opcode() == HloOpcode::kBroadcast) { alpha = alpha->operand(0); } - alpha = inst->operand(alpha->parameter_number()); + if (alpha->opcode() == HloOpcode::kParameter) { + alpha = inst->operand(alpha->parameter_number()); + } // TODO(b/74185543): Remove the following if block once we support fusion // with a non-constant as well. Then we will just always use the constant // on the device. 
@@ -2486,7 +2565,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( const HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - const HloInstruction* init_value = [&] { + const HloInstruction* init_value_operand = [&] { switch (inst->opcode()) { case HloOpcode::kSelectAndScatter: return inst->operand(2); @@ -2506,6 +2585,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( } }(); + const HloInstruction* init_value = init_value_operand; if (fused && init_value->opcode() == HloOpcode::kParameter) { init_value = hlo->operand(init_value->parameter_number()); } @@ -2562,6 +2642,11 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); + // If the init_value was fused into this reduce we have to generate it first. + if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { + CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); + TF_RETURN_IF_ERROR(HandleConstant(const_cast(init_value))); + } TF_RETURN_IF_ERROR(ParallelLoopEmitter( [=](const llvm_ir::IrArray::Index& index) { return GetIrArray(*init_value, *hlo) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 202231b82f3877c11cf932bd00a8aac350fd0afa..d228be81d47906850fa98e22a1d974500a7d34ed 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -77,6 +77,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleGenerateToken(HloInstruction* gen_token) override; Status EmitTargetElementLoop( const 
HloInstruction& hlo, diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..d541776f00ca9c0986fecd272930e5585852f6f3 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +GpuMultiOutputFusion::GpuMultiOutputFusion() : MultiOutputFusion(INT64_MAX) {} + +bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) { + auto get_element_instr = + [&](const HloInstruction* instr) -> const HloInstruction* { + const HloInstruction* element_instr = instr; + if (instr->opcode() == HloOpcode::kFusion) { + auto fused_expression_root = instr->fused_expression_root(); + if (instr->IsMultiOutputFusion()) { + // If possible, we want to pick a reduce operand of the fusion root, + // because it 
has the most constraints. + for (const auto* inst : fused_expression_root->operands()) { + if (inst->opcode() == HloOpcode::kReduce) { + return inst; + } + } + return fused_expression_root->operands()[0]; + } else { + element_instr = fused_expression_root; + } + } + return element_instr; + }; + + auto get_element_shape = [&](const HloInstruction* element_instr) { + // Special handling of kReduce instructions -- the fusion + // applies to the first operand. + if (element_instr->opcode() == HloOpcode::kReduce) { + return element_instr->operand(0)->shape(); + } + return element_instr->shape(); + }; + + // The shapes in all tuple operands should agree, unless it is a reduce. + // In that case, the operand of the reduce needs to have the same shape + // as the other tuple operands, but also we need to compare the output + // shapes of the reduces. + auto* element_instr_1 = get_element_instr(instr1); + auto* element_instr_2 = get_element_instr(instr2); + if (element_instr_1->opcode() == HloOpcode::kReduce && + element_instr_2->opcode() == HloOpcode::kReduce && + !ShapeUtil::Equal(element_instr_1->shape(), element_instr_2->shape())) { + return false; + } + // The elementwise output shapes must be the same (including layout). + return ShapeUtil::Equal(get_element_shape(element_instr_1), + get_element_shape(element_instr_2)); +} + +namespace { +bool IsReduction(HloInstruction* instr) { + if (instr->IsMultiOutputFusion()) { + for (const HloInstruction* operand : + instr->fused_expression_root()->operands()) { + if (operand->opcode() == HloOpcode::kReduce) { + return true; + } + } + return false; + } else if (instr->opcode() == HloOpcode::kFusion) { + return instr->fused_expression_root()->opcode() == HloOpcode::kReduce; + } else { + return instr->opcode() == HloOpcode::kReduce; + } +} +} // namespace + +bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { + // We can fuse reduces and loop fusions. 
+ return IsReduction(instr) || + (instr->opcode() == HloOpcode::kFusion && + instr->fusion_kind() == HloInstruction::FusionKind::kLoop && + // TODO(b/110202584): bitcasts make nested fusions, GPU has no support + // for nested fusions. + instr->fused_expression_root()->opcode() != HloOpcode::kBitcast); +} + +int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, + HloInstruction* instr2) { + tensorflow::gtl::FlatSet in_list; + for (auto instr : instr1->operands()) { + if (!IsProfitableOperand(instr)) { + continue; + } + in_list.insert(instr); + } + int64 profit = 0; + for (auto instr : instr2->operands()) { + if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) { + continue; + } + profit += ShapeUtil::ByteSizeOf(instr->shape()); + } + VLOG(2) << "Fusing instr1=" << instr1->name() << " instr2=" << instr2->name() + << ", the profit is =" << profit; + return profit; +} + +bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, + HloInstruction* instr2) { + if (!MultiOutputFusion::LegalToFuse(instr1, instr2)) { + return false; + } + // If we're fusing fusions only do it if the fusion kind matches. Loop fusions + // merge into bigger loop fusions and input (reduce) fusions become fusions + // with multiple reduce outputs. We could fuse reduce and loop fusions + // together too (the result being an input fusion) if we find cases where this + // improves things. 
+ CHECK(instr1->opcode() == HloOpcode::kFusion); + if (instr2->opcode() == HloOpcode::kFusion) { + return instr1->fusion_kind() == instr2->fusion_kind(); + } + return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..16db0e0f02d5cbf582f0e4236297b3d5407014b3 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +namespace xla { +namespace gpu { + +// Multi-output fusion of sibling and producer-consumer instructions for the +// GPU backend. +class GpuMultiOutputFusion : public MultiOutputFusion { + public: + GpuMultiOutputFusion(); + + protected: + // Test if instr1 and instr2 have the compatible shapes that can be legally + // fused.
+ bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) override; + + // We currently consider reduces, reduce fusions and loop fusions. + bool IsFusible(HloInstruction* instr) override; + + // This function estimates the amount of memory reads saved by merging + // instr1 and instr2 into one multi-output fusion instruction. For a fusion + // instruction, all the operands need to be loaded from memory. If we merge + // instr1 and instr2, common operands will not be loaded twice. The profit is + // estimated as the size of the common operands b/w instr1 and instr2. + int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) override; + + // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. + bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e7ceb7976b5d1957f706c12ec255e93991344b8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -0,0 +1,259 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace gpu { + +using InstructionFusionTest = HloTestBase; + +const char kModulePrefix[] = R"( + HloModule test_module + + scalar_add_computation { + scalar_lhs.0 = f32[] parameter(0) + scalar_rhs.0 = f32[] parameter(1) + ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0) + } + scalar_mul_computation { + scalar_lhs.1 = f32[] parameter(0) + scalar_rhs.1 = f32[] parameter(1) + ROOT mul.1 = f32[] add(scalar_lhs.1, scalar_rhs.1) + })"; + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { + // Fusion with reduce instruction root and a sibling reduce instruction + // sharing the same input param. 
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] constant(1) + fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation + reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[6400]{0} parameter(1) + mul = f32[6400]{0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[6400]{0} parameter(1) + r1 = f32[64,100]{0,1} reshape(p1.2) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[6400]{0} parameter(1) + fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2) + })")) + 
.ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[10,10]{1,0} parameter(1) + mul = f32[10,10]{1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[10,10]{1,0} parameter(1) + const.2 = f32[10]{0} parameter(0) + ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1.3 = f32[10,10]{1,0} parameter(1) + fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1 + p2 = f32[] parameter(2) + fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) { + // Two sibling fusions with reduce instruction roots sharing the same input + // param. 
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, + MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { + // Multi-output fusion with two reduce instructions root and a sibling reduce + // instruction sharing the same input param. 
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) { + const.1 = f32[] constant(1) + p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0) + mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1) + reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2) + } + + ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + const = f32[] constant(1) + fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation + get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0 + get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1 + reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, + MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { + // Verify that if we already have a multi-output fusion that we prefer to pick + // a reduce op from 
its operands for checking shape compatibility. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[10,10]{1,0} parameter(1) + mul = f32[10,10]{1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation + ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1) + } + + fused_computation_2 { + p1.2 = f32[10,10]{1,0} parameter(1) + const.2 = f32[10] parameter(0) + ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[10,10]{1,0} parameter(1) + p2 = f32[10]{0} parameter(2) + fusion.1 = (f32[10,10], f32[10]) fusion(p0, p1), kind=kInput, calls=fused_computation_1 + get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=0 + get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=1 + fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[6400]{0} parameter(0) + const.2 = f32[] constant(1) + ROOT div = f32[6400]{0} divide(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + 
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Divide())); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 696fa7e0194032b5c78bf11383c3280a62de07fa..6f4bb0580e8dfc1dce1cca0a60cc3dd9ea600fb3 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -33,8 +33,7 @@ class StreamAssignmentTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", VersionedComputationHandle(), - config); + return MakeUnique("test_module", config); } // Pre-canned shapes. 
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 06a5e0351b63270b61b998ca2211f480f256f759..a04aa4069d2344ca7b2e763cfeeb53abcbefc21d 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -26,6 +26,46 @@ namespace xla { using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; +/*static*/ +StatusOr HeapSimulator::MinimumMemoryForModule( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function) { + if (module_sequence.empty()) { + return 0; + } + + const HloModule* module = module_sequence.begin()->first->parent(); + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. 
+ TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; +} + +/*static*/ +StatusOr HeapSimulator::MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap* + memory_by_computation) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function, + HeapSimulator::Options(), memory_by_computation)); + return result.heap_size; +} + /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, @@ -46,9 +86,11 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options, + const tensorflow::gtl::FlatMap* + memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*module_sequence=*/nullptr); + /*module_sequence=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -219,6 +261,12 @@ Status HeapSimulator::RunComputation( Alloc(buffer, instruction); } } + // Account for the memory used by subcomputations when estimating the + // current heap size. + if (memory_by_computation_ != nullptr) { + algorithm_->AccountForSubcomputationMemory(instruction, + *memory_by_computation_); + } // If the whole module is sequential, we can save memory by running the // heap-simulation for sub-computations inline. E.g. 
the buffers for the @@ -286,12 +334,15 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence) + const SequentialHloOrdering::HloModuleSequence* module_sequence, + const tensorflow::gtl::FlatMap* + memory_by_computation) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - module_sequence_(module_sequence) { + module_sequence_(module_sequence), + memory_by_computation_(memory_by_computation) { debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); } @@ -460,6 +511,26 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } +void NoFragmentationStatsHeap::AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + // We only count the memory usage of the largest subcomputation, instead of + // adding them all, because subcomputations won't execute in parallel. 
+ int64 max_subcomputation_bytes = 0; + for (const auto* c : instruction->called_computations()) { + auto it = memory_by_computation.find(c); + if (it != memory_by_computation.end()) { + int64 subcomputation_bytes = it->second; + if (subcomputation_bytes > max_subcomputation_bytes) { + max_subcomputation_bytes = subcomputation_bytes; + } + } + } + max_heap_size_ = + std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes); +} + void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) { current_heap_size_ -= size; } diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 8b2b43a37a5c41d334e5338c6a6fad160f03a51e..811a6042df9434ac3f4bed71b9c093433e25c1bb 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -85,6 +85,23 @@ class HeapSimulator { const BufferValueFlatSet* buffers_to_assign; }; + // Returns the minimum memory required to compute an HLO module where all + // computations have been scheduled (represented by the given + // module_sequence), assuming no fragmentation. + static StatusOr MinimumMemoryForModule( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function); + + // Returns the minimum memory required to compute the given computation, + // assuming no fragmentation. + static StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); + // Run the heap simulation with the given algorithm, assuming the given // module_sequence, which must contain a topologically-consistent total // ordering of all instructions within each computation. 
The result is invalid @@ -111,7 +128,9 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + const Options& options = Options(), + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); private: // If 'module_sequence' is non-null, it is used to find kCall and kWhile @@ -120,7 +139,9 @@ class HeapSimulator { HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence); + const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); ~HeapSimulator(); Status RunComputation( @@ -144,7 +165,13 @@ class HeapSimulator { const std::unique_ptr algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; + // module_sequence_ is set by buffer assignment, and memory_by_computation_ is + // set by hlo scheduling. Then, in RunComputation, we check both in order to + // handle subcomputations. It would be good to unify the handling of + // subcomputations, but it's not clear how. const SequentialHloOrdering::HloModuleSequence* module_sequence_; + const tensorflow::gtl::FlatMap* + memory_by_computation_; // In addition to Alloc and Free, the heap simulator exposes a concept of // buffer sharing. When ShareBuffer is called, instead of allocating new @@ -189,6 +216,11 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; + virtual void AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) {} + // Free de-allocates a previously allocated buffer. 
virtual void Free(const BufferValue* buffer, int64 size) = 0; @@ -207,7 +239,14 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { ~NoFragmentationStatsHeap() override = default; void Alloc(const BufferValue* buffer, int64 size) override; + + void AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) override; + void Free(const BufferValue* buffer, int64 size) override; + Result Finish() override; private: diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 6271652412c2979ff926702f12722102344b0dfb..93d7a141258a3186b10cf2728b70a034488a84f2 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -34,6 +34,65 @@ limitations under the License. namespace xla { namespace { +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = 
+ module->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) + .ValueOrDie()); +} + const char kAlloc[] = "Alloc"; const char kFree[] = "Free"; const char kFinish[] = "Finish"; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 
1f7c1cffd324ad2f4e4cdf11046c8459b8ceb6d5..e201359d3d25b7d2dda852762c6de1fcb75685d7 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -145,6 +145,7 @@ message HloInstructionProto { repeated int64 operand_ids = 36; repeated int64 control_predecessor_ids = 37; repeated int64 called_computation_ids = 38; + repeated int64 replica_group_ids = 44; xla.OpSharding sharding = 40; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index a88283ed9a6459b4fa9310e160b59c77d51f1027..0a948cc390fed7daed3e0cc938bf59cbcfd9b4df 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -493,6 +493,16 @@ StatusOr> HloAliasAnalysis::Run( bool HloAliasAnalysis::HasLiveRangeInterference( const HloOrdering& ordering) const { for (const HloBuffer& buffer : buffers()) { + CHECK(!buffer.values().empty()); + if (ShapeUtil::IsToken(buffer.values().front()->shape())) { + // Tokens have no on-device representation and cannot interfere. + for (const HloValue* value : buffer.values()) { + // If one of the values is a token, all values must be a token. + DCHECK(ShapeUtil::IsToken(value->shape())); + } + continue; + } + // Check that the values in the buffer are totally ordered with respect to // 'ordering'. Begin by sorting the values with respect to 'ordering' with a // tie-break using value ID. The tie-break is necessary because we need a @@ -517,7 +527,6 @@ bool HloAliasAnalysis::HasLiveRangeInterference( // a buffer and A interferes with C, then necessarily A also interferes // with B. So to check interference you only need to check interference // between A and B, and between B and C. 
- CHECK(!values.empty()); for (int i = 1; i < values.size(); ++i) { if (!ordering.IsDefinedBefore(*values[i - 1], *values[i])) { VLOG(1) << values[i - 1]->ToShortString() << " and " diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ed0ea39ff55dca6931fac6f93ddcddd2716ec505..ef8bb030fbc7a99e1fc907c0b1c1e9b0a16ecbd1 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -64,7 +64,7 @@ HloComputation::HloComputation( const string& name, int parameter_count, std::vector>* instructions, HloInstruction* root_instruction, HloInstruction* fusion_instruction) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), unique_id_(-1), root_instruction_(root_instruction), fusion_instruction_(fusion_instruction) { @@ -234,7 +234,6 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) { TF_RET_CHECK(instruction_iterators_.count(instruction) != 0); auto inst_it = instruction_iterators_.at(instruction); (*inst_it)->set_parent(nullptr); - instruction->DetachFromOperands(); instructions_.erase(inst_it); return Status::OK(); } @@ -357,7 +356,6 @@ std::list HloComputation::MakeInstructionPostOrder() const { std::list post_order; std::list trace_instructions; tensorflow::gtl::FlatSet added_instructions; - std::vector dfs_stack; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. 
Add trace @@ -525,21 +523,7 @@ HloInstruction* HloComputation::CreateFusionInstruction( StatusOr HloComputation::DeepCopyHelper( HloInstruction* instruction, const ShapeTree* indices_to_copy, ShapeTree* copies_added, ShapeIndex* index) { - if (ShapeUtil::IsArray(instruction->shape())) { - if (indices_to_copy == nullptr || indices_to_copy->element(*index)) { - // Use kCopy to copy array elements - HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary( - instruction->shape(), HloOpcode::kCopy, instruction)); - if (copies_added != nullptr) { - *copies_added->mutable_element(*index) = copy; - } - return copy; - } else { - // Array elements which are not to be copied are passed through - // transparently. - return instruction; - } - } else if (ShapeUtil::IsTuple(instruction->shape())) { + if (ShapeUtil::IsTuple(instruction->shape())) { std::vector elements; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); i++) { @@ -556,9 +540,27 @@ StatusOr HloComputation::DeepCopyHelper( index->pop_back(); } return AddInstruction(HloInstruction::CreateTuple(elements)); + } + if (ShapeUtil::IsToken(instruction->shape())) { + // Tokens have no on-device representation and cannot be copied. Pass + // through transparently. + return instruction; + } + + // Array shape. + TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape())); + if (indices_to_copy == nullptr || indices_to_copy->element(*index)) { + // Use kCopy to copy array elements + HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kCopy, instruction)); + if (copies_added != nullptr) { + *copies_added->mutable_element(*index) = copy; + } + return copy; } else { - return FailedPrecondition( - "Can only copy array and tuple shaped instructions"); + // Elements which are not to be copied are passed through + // transparently. 
+ return instruction; } } @@ -864,15 +866,6 @@ std::unique_ptr HloComputation::CloneWithReplacements( } } context->MapComputation(this, result.get()); - // We cloned the elements of 'replacements', so they're all going to be - // destroyed. HloInstructions need to be detached from their operands before - // they're destroyed, otherwise they stick around in the operands' users lists - // and cause use-after-frees. - for (auto& kv : replacements) { - if (std::unique_ptr& new_instr = kv.second) { - new_instr->DetachFromOperands(); - } - } return result; } diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 25469a54c48f4f5cab478aba929f1cc18de8b81f..3f59d31bb9123a480864ddfca939ec3c032298c9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -371,6 +371,38 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { } } +TEST_F(HloComputationTest, DeepCopyToken) { + // Test that DeepCopyInstruction properly handles tokens which should not be + // copied. + auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); + + // No copy should be added. + EXPECT_THAT(copy, op::GenerateToken()); +} + +TEST_F(HloComputationTest, DeepCopyTokenTuple) { + // Test that DeepCopyInstruction properly handles tokens which should not be + // copied. 
+ auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); + + // Only the array (second tuple element) should be copied. The token is passed + // through transparently. + EXPECT_THAT(copy, op::Tuple(op::GetTupleElement(tuple), + op::Copy(op::GetTupleElement(tuple)))); +} + TEST_F(HloComputationTest, CycleDetection) { // Test whether the visitor can detect cycles in the graph. auto builder = HloComputation::Builder(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 94c9c7eabcc99d4cf61f535925c068a9b55ed136..762e1afc71b108b2e32b5a7f7f1bbeb783fc6fbd 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -172,15 +172,22 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { + current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2; return Status::OK(); } -Status HloCostAnalysis::HandleDynamicSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleDynamicSlice( + const HloInstruction* dynamic_slice) { + current_properties_[kBytesAccessedKey] = + shape_size_(dynamic_slice->shape()) * 2; return Status::OK(); } -Status HloCostAnalysis::HandleDynamicUpdateSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleDynamicUpdateSlice( + const HloInstruction* dynamic_update_slice) { + 
current_properties_[kBytesAccessedKey] = + shape_size_(dynamic_update_slice->operand(1)->shape()) * 2; return Status::OK(); } @@ -386,6 +393,10 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { auto lhs = convolution->operand(0); auto rhs = convolution->operand(1); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index d17678d20f2a23fd98d18b77d5fb25853901a789..0d66736fe1d0677d13a63ede7a203d6ac20c76f5 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -97,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleGenerateToken(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 16fdda8a8b9ade09ea31cda1f4cf5e8ff2c0a081..d22bef56730da194816b4ee89dc3196439b350f9 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -460,5 +460,51 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { EXPECT_EQ(analysis.flop_count(), 1472); } +TEST_F(HloCostAnalysisTest, Slice) { + // Test the analysis on a slice. 
+ XlaBuilder builder("slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.Slice(x, {0}, {1}, {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + +TEST_F(HloCostAnalysisTest, DynamicSlice) { + // Test the analysis on a slice. + XlaBuilder builder("dynamic-slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.DynamicSlice(x, builder.ConstantR1({1}), {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + +TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { + // Test the analysis on a slice. + XlaBuilder builder("dynamic-update-slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.DynamicUpdateSlice(x, builder.ConstantR1({1.0}), + builder.ConstantR1({1})); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index dab946a099fa0066a4a0d42ce29077b9de6a486e..a0ee8896230d6dcacb5a8eb607fc00ae5226cfa5 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -135,17 +135,18 @@ StatusOr HloCSE::Run(HloModule* module) { // instruction for each class. 
tensorflow::gtl::FlatSet - representatives(/*N=*/1024, &CseHash, cse_equal); - + representatives(/*N=*/computation->instruction_count() + 1, &CseHash, + cse_equal); for (auto instruction : computation->MakeInstructionPostOrder()) { // If the instruction has zero operands (constants, parameters, etc.) skip // over it. if (instruction->operand_count() == 0) { continue; } - - // Skip instructions which have side effects. - if (instruction->HasSideEffect()) { + // Skip instructions which have side effects or are a domain (which must + // not be CSE-ed). + if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kDomain) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index cc130a4900dc162d4b416116fbe879fec37136a2..d0200058683b2db8f5f0469d6c643014881f741e 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -931,16 +931,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } const HloUse& use = value.uses()[0]; - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. 
+ return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -967,6 +968,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( use.operand_number == other_add_operand_index; } } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, @@ -998,8 +1000,13 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( }) != uses.end(); return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; } - // Check if 'user' is element-wise. - return user->IsElementwise(); + + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + // + // Multi-output fusion will fail the check here as tuples are not considered + // an elementwise operation. 
+ return user->IsElementwiseOnOperand(user->operand_index(operand)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 5798326dcbf65c3c34748afb02afab1dc7af9147..db1822ec47a7f52e2c3ef8dcbf433cd787ef75ab 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1974,6 +1974,89 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); } +TEST_F(CanShareOperandBufferWithUserTest, + NonElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "param0")); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0)); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, neg, {0, 1})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {reverse, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + MultiOutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + + auto copy0 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, 
HloOpcode::kCopy, param0)); + auto copy1 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {1})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {1})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + ElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {exp, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { auto builder = HloComputation::Builder(TestName()); @@ -2048,6 +2131,46 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { fusion, {})); } +TEST_F(CanShareOperandBufferWithUserTest, + FusedDynamicUpdateSliceWithConvertCantShare) { + auto builder = 
HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + auto convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape_bf16, gte1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape_bf16, convert1, update, starts)); + + auto convert2 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape, dynamic_update_slice)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {convert2, dynamic_update_slice, starts, update, convert1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can't share with tuple element 1. 
+ EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { auto builder = HloComputation::Builder(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 1e78d775c8e172a272a03fbd1101cef365e6dc2d..33424019b93feff862c6e3e268ae3980bacc9142 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -300,12 +300,6 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( instruction->CloneWithNewOperands(instruction->shape(), operands); auto result = Evaluate(cloned_instruction.get()); - // Clean up our cloned instructions before returning. - cloned_instruction->DetachFromOperands(); - for (auto& operand : owned_operands) { - operand->DetachFromOperands(); - } - return result; } @@ -321,7 +315,6 @@ StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( rhs_instr.get()); auto result = Evaluate(cloned_instruction.get()); - cloned_instruction->DetachFromOperands(); return result; } @@ -334,7 +327,6 @@ StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); auto result = Evaluate(cloned_instruction.get()); - cloned_instruction->DetachFromOperands(); return result; } @@ -372,7 +364,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. 
const Shape& reference_shape = operands[0]->shape(); - CHECK(!ShapeUtil::IsTuple(reference_shape)); + CHECK(ShapeUtil::IsArray(reference_shape)); const int64 rank = ShapeUtil::Rank(reference_shape); const int64 concat_dim = concatenate->dimensions()[0]; CHECK_GE(concat_dim, 0); @@ -383,7 +375,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (int64 i = 1; i < operands.size(); ++i) { const Shape& operand_shape = operands[i]->shape(); - CHECK(!ShapeUtil::IsTuple(operand_shape)); + CHECK(ShapeUtil::IsArray(operand_shape)); // Accumulate the concat dimension from all tensors taking part to the // operation. concat_dimensions[concat_dim] += @@ -910,6 +902,11 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } +Status HloEvaluator::HandleGenerateToken(HloInstruction* token) { + evaluated_[token] = Literal::CreateToken(); + return Status::OK(); +} + Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const auto result_shape = get_tuple_element->shape(); const int64 index = get_tuple_element->tuple_index(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index b53d5644de5a17c52bdbf2593ce52f0227008a00..fc2fc9437b238a2e519401b2b121dfbef070e2dc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -174,6 +174,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleBroadcast(HloInstruction* broadcast) override; + Status HandleGenerateToken(HloInstruction* token) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. 
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 84b4ead2dd28caa40b6d7830a1e1401be88b6b36..72eb9930e92c340ab9f42cd563c27507623b2ba7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -1248,7 +1248,7 @@ void BM_ReducePrecisely(int num_iters) { HloComputation::Builder b("BM_ReducePrecisely"); HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_ReducePrecisely", VersionedComputationHandle(), config); + HloModule module("BM_ReducePrecisely", config); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 std::vector v(kNumElements, 1.0f); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 13f46407e33e36bdbef4c9032630101d6c18268f..bc7340aa036ecb322b37fbe4c72fa43485b2f57d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -778,7 +778,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override { CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); - CHECK(!ShapeUtil::IsTuple(select->shape())); + CHECK(ShapeUtil::IsArray(select->shape())); std::function select_op = [](bool pred, ReturnT on_true, ReturnT on_false) { if (pred) { @@ -1103,7 +1103,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandlePad(HloInstruction* pad) override { - CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); + CHECK(ShapeUtil::IsArray(pad->operand(0)->shape())); // Padding value must be scalar. 
CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), @@ -1116,7 +1116,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { /*padding_config=*/pad->padding_config())); CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); // Create new HLO of padded shape with padding value. @@ -1182,7 +1182,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { dynamic_slice->dynamic_slice_sizes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); TF_RET_CHECK( primitive_util::IsIntegralType(start_indices->shape().element_type())); @@ -1237,7 +1237,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { operand->shape(), update->shape(), start_indices->shape())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); TF_RET_CHECK( primitive_util::IsIntegralType(start_indices->shape().element_type())); @@ -1393,7 +1393,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); @@ -1613,7 +1613,7 @@ class 
HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanStringWithLayout(inferred_return_shape); const Literal& operand_literal = diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 61612bebd1e906d2d055e2f70de29da53275d4e8..ab224021c54fb3f5c5b69d0b633a080c304d5edd 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -28,6 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -723,11 +725,25 @@ string HloDotDumper::DumpRootTag() { to_id, node_body, node_shape, NodeColorAttributes(color)); } +static const HloConstantInstruction* TryGetFusionParameterConstant( + const HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) { + return nullptr; + } + const HloInstruction* fusion = instr->parent()->FusionInstruction(); + const HloInstruction* operand = fusion->operand(instr->parameter_number()); + return DynCast(operand); +} + bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // If a node: // - // - is a tuple-shaped parameter, - // - is not a parameter to a fusion node, + // - is a parameter of a fusion node which is bound to a constant, + // + // or + // + // - is a tuple-shaped parameter, and + // - is not a parameter to a fusion node, and // 
- has at least kMinUsersToOmit users shown, and // - all of the shown users are get-tuple-elements, // @@ -735,6 +751,9 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // // This helps us handle the common case where a while loop body has one big // tuple-shaped parameter. + if (TryGetFusionParameterConstant(instr) != nullptr) { + return true; + } const int kMinUsersToOmit = 3; return instr->opcode() == HloOpcode::kParameter && ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && @@ -806,26 +825,26 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { - auto stringify_constant = [](const HloInstruction* constant) { + auto stringify_constant = [](const HloConstantInstruction* constant) { const auto& shape = constant->shape(); // If the shape has a dimension of size zero, print it as e.g. // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. - if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) { + if (ShapeUtil::IsZeroElementArray(shape)) { return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); } // Print the literal value of constants with <= K elements. 
optional elem_count; - if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { + if (ShapeUtil::IsArray(shape)) { elem_count = 1; for (int64 dim : shape.dimensions()) { *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { + if (elem_count.has_value() && *elem_count <= 8) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -841,29 +860,26 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( ShapeUtil::HumanString(constant->shape())); }; - // Special case: If instr is a parameter to a fusion node, check whether the - // corresponding operand to the fusion node is a constant. - if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { - const HloInstruction* fusion = instr->parent()->FusionInstruction(); - const HloInstruction* operand = fusion->operand(instr->parameter_number()); - if (operand->opcode() != HloOpcode::kConstant) { - return ""; - } - return StrCat("constant ", stringify_constant(operand)); - } - std::vector lines; for (int64 i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); + const auto* constant_operand = DynCast(operand); optional operand_str; - if (operand->opcode() == HloOpcode::kConstant) { - operand_str = stringify_constant(operand); + if (constant_operand != nullptr) { + operand_str = stringify_constant(constant_operand); } else if (ShouldMergeIntoUsers(operand)) { - // Special case: If the operand is a parameter, use its parameter number - // rather than its name, because that's generally how people think of the - // node. + // Special case: If the operand is a parameter to a fusion node and it + // always has a constant value, display it like a regular constant. + // + // For other parameters, use the parameter number rather than the proper + // name, because that's generally how people think of the node. 
if (operand->opcode() == HloOpcode::kParameter) { - operand_str = Printf("Parameter %lld", operand->parameter_number()); + if (const HloConstantInstruction* constant = + TryGetFusionParameterConstant(operand)) { + operand_str = stringify_constant(constant); + } else { + operand_str = Printf("Parameter %lld", operand->parameter_number()); + } } else { operand_str = operand->name(); } @@ -897,11 +913,14 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { const auto kParameterColor = kOrange; // Special case: If this instruction has a parameter merged into it, paint it - // the same color as a parameter. + // the same color as a parameter. Unless the merged-in parameter is a + // parameter to a fusion node that is bound to a constant -- these aren't + // "real" parameters from the user's perspective. if (std::any_of(instr->operands().begin(), instr->operands().end(), [&](const HloInstruction* operand) { return operand->opcode() == HloOpcode::kParameter && - ShouldMergeIntoUsers(operand); + ShouldMergeIntoUsers(operand) && + TryGetFusionParameterConstant(operand) == nullptr; })) { return kParameterColor; } @@ -964,6 +983,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcast: case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: @@ -975,7 +995,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { } return kGreen; case HloOpcode::kConcatenate: - case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: case HloOpcode::kPad: @@ -997,6 +1016,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kWhite; } return kGreen; + case HloOpcode::kCopy: + // Emphasize copy nodes, which are either physical transposes (and thus + // significant), or copies of read-only buffers (and thus dead weight). 
+ return kGreen; case HloOpcode::kConvolution: case HloOpcode::kDot: case HloOpcode::kFft: diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 8e52d926d85f1ce6fabeb2dedd2f8e0fe0c2051d..68f41a1cbb4db228f5dcf8b4a6130f05e81262a8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -121,7 +121,7 @@ TEST(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(-42))); - instruction->set_name("i_am_a_constant_root_instruction"); + instruction->SetAndSanitizeName("i_am_a_constant_root_instruction"); HloModuleConfig config; HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 8d7604fae1e121b771915e0852ab44005da92fbe..0b4dd6412f189d12dfa9e343ef516854c08dc4c3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include -#include #include #include #include @@ -66,6 +65,9 @@ StatusOr> HloInstruction::CreateFromProto( const auto operands = [&instruction_map, &proto](int index) { return instruction_map.at(proto.operand_ids(index)); }; + const auto computations = [&computation_map, &proto](int index) { + return computation_map.at(proto.called_computation_ids(index)); + }; switch (opcode) { // Ops migrated to subclasses. 
case HloOpcode::kBatchNormTraining: @@ -86,6 +88,187 @@ StatusOr> HloInstruction::CreateFromProto( operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; + case HloOpcode::kFft: { + CHECK_EQ(proto.operand_ids_size(), 1); + std::vector fft_length(proto.fft_length().begin(), + proto.fft_length().end()); + instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), + tensorflow::gtl::ArraySlice(fft_length)); + break; + } + case HloOpcode::kSend: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateSend(operands(0), proto.channel_id()); + break; + case HloOpcode::kSendDone: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateSendDone(operands(0)); + break; + case HloOpcode::kRecv: + CHECK_EQ(proto.operand_ids_size(), 0); + instruction = + CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id()); + break; + case HloOpcode::kRecvDone: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateRecvDone(operands(0)); + break; + case HloOpcode::kReverse: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateReverse(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kConcatenate: { + CHECK_EQ(proto.dimensions_size(), 1); + std::vector concat_operands(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + concat_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateConcatenate(proto.shape(), concat_operands, + proto.dimensions(0)); + break; + } + case HloOpcode::kReduce: + CHECK_EQ(proto.operand_ids_size(), 2); + CHECK_EQ(proto.called_computation_ids_size(), 1); + instruction = CreateReduce(proto.shape(), operands(0), operands(1), + std::vector(proto.dimensions().begin(), + proto.dimensions().end()), + computations(0)); + break; + case HloOpcode::kTranspose: + CHECK_EQ(proto.operand_ids_size(), 1); + 
instruction = + CreateTranspose(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kBroadcast: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = + CreateBroadcast(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kMap: { + CHECK_EQ(proto.called_computation_ids_size(), 1); + std::vector map_operands(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + map_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateMap(proto.shape(), map_operands, computations(0)); + break; + } + case HloOpcode::kSlice: { + CHECK_EQ(proto.operand_ids_size(), 1); + std::vector slice_starts, slice_limits, slice_strides; + for (const HloInstructionProto::SliceDimensions& slice_dimensions : + proto.slice_dimensions()) { + slice_starts.push_back(slice_dimensions.start()); + slice_limits.push_back(slice_dimensions.limit()); + slice_strides.push_back(slice_dimensions.stride()); + } + instruction = CreateSlice(proto.shape(), operands(0), slice_starts, + slice_limits, slice_strides); + break; + } + case HloOpcode::kConstant: { + // TODO(b/110214922): Revert this to CHECK(proto.has_literal()). 
+ if (proto.has_literal()) { + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateConstant(std::move(literal)); + } else { + instruction = MakeUnique(proto.shape()); + } + break; + } + case HloOpcode::kTrace: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Trace instruction should have 1 operand but sees " + << proto.operand_ids_size(); + CHECK(proto.has_literal()); + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + break; + } + case HloOpcode::kFusion: { + // In the proto, fused computations are held exclusively within the + // HloInstructionProto and do not appear as an HloComputationProto within + // the HloModuleProto. + TF_RET_CHECK(!proto.fusion_kind().empty()); + TF_ASSIGN_OR_RETURN(FusionKind fusion_kind, + StringToFusionKind(proto.fusion_kind())); + + // Find the fused computation and set its fusion instruction. + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Expect 1 called computation for fusion instruction, but sees " + << proto.called_computation_ids_size(); + const int64 fusion_id = proto.called_computation_ids(0); + auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + TF_RET_CHECK(fused_computation != nullptr) + << "No fusion computation with id " << fusion_id; + std::vector fusion_operands(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + fusion_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateFusion(proto.shape(), fusion_kind, fusion_operands, + fused_computation); + break; + } + case HloOpcode::kRng: { + std::vector rng_parms(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + rng_parms.begin(), [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); 
+ instruction = CreateRng(proto.shape(), proto.distribution(), rng_parms); + break; + } + case HloOpcode::kParameter: + instruction = CreateParameter(proto.parameter_number(), proto.shape(), + proto.name()); + break; + case HloOpcode::kGetTupleElement: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateGetTupleElement(proto.shape(), operands(0), + proto.tuple_index()); + break; + case HloOpcode::kReducePrecision: + instruction = + CreateReducePrecision(proto.shape(), operands(0), + proto.exponent_bits(), proto.mantissa_bits()); + break; + case HloOpcode::kInfeed: + instruction = CreateInfeed(proto.shape(), proto.infeed_config()); + break; + case HloOpcode::kOutfeed: + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + proto.outfeed_config()); + break; + case HloOpcode::kCrossReplicaSum: { + CHECK_EQ(proto.called_computation_ids_size(), 1); + std::vector all_operands(proto.operand_ids_size()); + c_transform(proto.operand_ids(), all_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateCrossReplicaSum( + proto.shape(), all_operands, computations(0), + /*replica_group_ids=*/ + std::vector(proto.replica_group_ids().begin(), + proto.replica_group_ids().end()), + /*barrier=*/""); + break; + } default: { instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -99,59 +282,23 @@ StatusOr> HloInstruction::CreateFromProto( TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) ->AddControlDependencyTo(instruction.get())); } + if (instruction->opcode() != HloOpcode::kFusion) { + for (const int64 computation_id : proto.called_computation_ids()) { + TF_RET_CHECK(ContainsKey(computation_map, computation_id)) + << "No computation with id " << computation_id; + instruction->called_computations_.push_back( + computation_map.at(computation_id)); + } + } break; } } - // In the proto, fused computations are held 
exclusively within the - // HloInstructionProto and do not appear as an HloComputationProto within the - // HloModuleProto. - if (instruction->opcode() == HloOpcode::kFusion) { - TF_RET_CHECK(!proto.fusion_kind().empty()); - TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, - StringToFusionKind(proto.fusion_kind())); - - // Find the fused computation and set its fusion instruction. - TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "Expect 1 called computation for fusion instruction, but sees " - << proto.called_computation_ids_size(); - const int64 fusion_id = proto.called_computation_ids(0); - auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); - TF_RET_CHECK(fused_computation != nullptr) - << "No fusion computation with id " << fusion_id; - fused_computation->SetFusionInstruction(instruction.get()); - instruction->called_computations_.push_back(fused_computation); - } else { - for (const int64 computation_id : proto.called_computation_ids()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_id)) - << "No computation with id " << computation_id; - instruction->called_computations_.push_back( - computation_map.at(computation_id)); - } - } - - if (instruction->opcode() == HloOpcode::kTrace) { - TF_RET_CHECK(instruction->operands().size() == 1) - << "Trace instruction should have 1 operand but sees " - << instruction->operands().size(); - instruction->mutable_operand(0)->set_tracing(instruction.get()); - } - TF_RET_CHECK(!proto.name().empty()); - instruction->name_ = proto.name(); - + instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); - if (proto.has_literal()) { - TF_ASSIGN_OR_RETURN(instruction->literal_, - Literal::CreateFromProto(proto.literal())); - } - instruction->parameter_number_ = proto.parameter_number(); - instruction->tuple_index_ = proto.tuple_index(); - for (int64 dimension : proto.dimensions()) { - 
instruction->dimensions_.push_back(dimension); - } if (proto.has_window()) { instruction->window_ = MakeUnique(proto.window()); } @@ -164,14 +311,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->dot_dimension_numbers_ = MakeUnique(proto.dot_dimension_numbers()); } - for (const HloInstructionProto::SliceDimensions& slice_dimensions : - proto.slice_dimensions()) { - instruction->slice_starts_.push_back(slice_dimensions.start()); - instruction->slice_limits_.push_back(slice_dimensions.limit()); - instruction->slice_strides_.push_back(slice_dimensions.stride()); - } - instruction->exponent_bits_ = proto.exponent_bits(); - instruction->mantissa_bits_ = proto.mantissa_bits(); + for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size); } @@ -179,16 +319,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->padding_config_ = MakeUnique(proto.padding_config()); } - instruction->outfeed_config_ = proto.outfeed_config(); - instruction->distribution_ = proto.distribution(); - instruction->channel_id_ = proto.channel_id(); - instruction->infeed_config_ = proto.infeed_config(); instruction->custom_call_target_ = proto.custom_call_target(); - instruction->outfeed_shape_ = proto.outfeed_shape(); - instruction->fft_type_ = proto.fft_type(); - for (int64 fft_len : proto.fft_length()) { - instruction->fft_length_.push_back(fft_len); - } if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -212,52 +343,29 @@ StatusOr> HloInstruction::CreateFromProto( /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); - instruction->parameter_number_ = parameter_number; - instruction->name_ = name; - return instruction; + return MakeUnique(parameter_number, shape, name); } /* static */ std::unique_ptr HloInstruction::CreateTrace( const 
string& tag, HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); - instruction->operands_.push_back(operand); - instruction->literal_ = Literal::CreateR1U8(tag); - operand->set_tracing(instruction.get()); - return instruction; + return MakeUnique(tag, operand); } /* static */ std::unique_ptr HloInstruction::CreateConstant( std::unique_ptr literal) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape())); - instruction->literal_ = std::move(literal); - return instruction; + return MakeUnique(std::move(literal)); } /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { - CHECK(ShapeUtil::IsTuple(operand->shape())); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape)); - instruction->tuple_index_ = index; - instruction->AppendOperand(operand); - return instruction; + return MakeUnique(shape, operand, index); } /* static */ std::unique_ptr HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRng, shape)); - instruction->distribution_ = distribution; - instruction->shape_ = shape; - for (HloInstruction* param : parameters) { - instruction->AppendOperand(param); - } - return instruction; + return MakeUnique(shape, distribution, parameters); } /* static */ std::unique_ptr HloInstruction::CreateNary( @@ -372,13 +480,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* map_computation, tensorflow::gtl::ArraySlice static_operands) { - CHECK(static_operands.empty()) << "static_operands not yet supported"; - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape)); - for (auto operand : operands) { - 
instruction->AppendOperand(operand); - } - instruction->called_computations_.push_back(map_computation); - return instruction; + return MakeUnique(shape, operands, map_computation, + static_operands); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( @@ -404,11 +507,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, tensorflow::gtl::ArraySlice fft_length) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape)); - instruction->AppendOperand(operand); - instruction->fft_type_ = fft_type; - instruction->fft_length_.assign(fft_length.begin(), fft_length.end()); - return instruction; + return MakeUnique(shape, operand, fft_type, fft_length); } /* static */ std::unique_ptr HloInstruction::CreateDot( @@ -441,12 +540,8 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape)); - instruction->AppendOperand(operand); - instruction->exponent_bits_ = exponent_bits; - instruction->mantissa_bits_ = mantissa_bits; - return instruction; + return MakeUnique( + shape, operand, exponent_bits, mantissa_bits); } /* static */ std::unique_ptr @@ -454,92 +549,64 @@ HloInstruction::CreateCrossReplicaSum( const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* reduce_computation, tensorflow::gtl::ArraySlice replica_group_ids, - const tensorflow::gtl::optional& channel_id) { - // TODO(b/79737069): Remove the CHECK when supported. 
- CHECK(replica_group_ids.empty()); - CHECK(!channel_id.has_value()); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->called_computations_.push_back(reduce_computation); - return instruction; + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id) { + return MakeUnique( + shape, operands, reduce_computation, replica_group_ids, barrier, + all_reduce_id); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( const Shape& shape, const string& config) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape)); - instruction->set_infeed_config(config); - return instruction; + return MakeUnique(shape, config); } /* static */ std::unique_ptr HloInstruction::CreateOutfeed( const Shape& shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil())); - CHECK(ShapeUtil::Compatible(operand->shape(), shape)) - << "Outfeed shape " << shape << " must be compatible with operand shape " - << operand->shape(); - instruction->AppendOperand(operand); - instruction->outfeed_config_ = std::string(outfeed_config); - instruction->outfeed_shape_ = shape; - return instruction; + return MakeUnique(shape, operand, outfeed_config); } /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, int64 channel_id) { - // Send instruction produces a tuple of {aliased operand, U32 context}. 
- Shape output_shape = ShapeUtil::MakeTupleShape( - {operand->shape(), ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = channel_id; - return instruction; + return MakeUnique(operand, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kSend) + auto send_operand = DynCast(operand); + CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - auto instruction = WrapUnique( - new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil())); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(send_operand); } /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, int64 channel_id) { - // Recv instruction produces a tuple of {receive buffer, U32 context}. 
- Shape output_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape)); - instruction->channel_id_ = channel_id; - return instruction; + return MakeUnique(shape, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kRecv) + auto recv_operand = DynCast(operand); + CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(recv_operand); } /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); + return MakeUnique(shape, operand, dimensions); +} + +/* static */ std::unique_ptr +HloInstruction::CreateGenerateToken( + tensorflow::gtl::ArraySlice operands) { + auto instruction = WrapUnique(new HloInstruction( + HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape())); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } return instruction; } @@ -576,18 +643,8 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape)); - instruction->AppendOperand(operand); - instruction->slice_starts_.assign(start_indices.begin(), start_indices.end()); - 
instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end()); - instruction->slice_strides_.assign(strides.begin(), strides.end()); - // For backward compatibility with old serialized computations: if there are - // no strides, assume all strides are 1. - // TODO(b/63317920): remove this code. - if (instruction->slice_strides_.empty()) { - instruction->slice_strides_ = std::vector(start_indices.size(), 1LL); - } - return instruction; + return MakeUnique(shape, operand, start_indices, + limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( @@ -618,13 +675,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConcatenate( const Shape& shape, tensorflow::gtl::ArraySlice operands, int64 dimension) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->dimensions_.push_back(dimension); - return instruction; + return MakeUnique(shape, operands, dimension); } /* static */ std::unique_ptr HloInstruction::CreateConvert( @@ -647,13 +698,8 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, const Shape& shape, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape)); - instruction->AppendOperand(arg); - instruction->AppendOperand(init_value); - instruction->dimensions_.assign(dimensions_to_reduce.begin(), - dimensions_to_reduce.end()); - instruction->called_computations_.push_back(reduce_computation); - return instruction; + return MakeUnique( + shape, arg, init_value, dimensions_to_reduce, reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( @@ -674,8 +720,8 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* 
scale, HloInstruction* offset, float epsilon, int64 feature_index) { - return WrapUnique(new HloBatchNormTrainingInstruction( - shape, operand, scale, offset, epsilon, feature_index)); + return MakeUnique( + shape, operand, scale, offset, epsilon, feature_index); } /* static */ std::unique_ptr @@ -683,8 +729,8 @@ HloInstruction::CreateBatchNormInference( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, float epsilon, int64 feature_index) { - return WrapUnique(new HloBatchNormInferenceInstruction( - shape, operand, scale, offset, mean, variance, epsilon, feature_index)); + return MakeUnique( + shape, operand, scale, offset, mean, variance, epsilon, feature_index); } /* static */ std::unique_ptr @@ -693,9 +739,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* variance, HloInstruction* grad_output, float epsilon, int64 feature_index) { - return WrapUnique( - new HloBatchNormGradInstruction(shape, operand, scale, mean, variance, - grad_output, epsilon, feature_index)); + return MakeUnique(shape, operand, scale, mean, + variance, grad_output, epsilon, + feature_index); } /* static */ std::unique_ptr @@ -718,12 +764,8 @@ HloInstruction::CreateSelectAndScatter( /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(broadcast_dimensions.begin(), - broadcast_dimensions.end()); - return instruction; + return MakeUnique(shape, operand, + broadcast_dimensions); } /* static */ std::unique_ptr @@ -802,53 +844,28 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, 
tensorflow::gtl::ArraySlice dimensions) { - CHECK_EQ(shape.dimensions().size(), dimensions.size()); - CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); - CHECK(std::equal(operand->shape().dimensions().begin(), - operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())) - << "shape: " << ShapeUtil::HumanString(shape) - << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << Join(dimensions, ", ") << "}"; - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); - return instruction; + return MakeUnique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - instruction->fusion_kind_ = fusion_kind; - instruction->name_ = "fusion"; - instruction->set_parent(fused_root->parent()); - instruction->set_metadata(fused_root->metadata()); - instruction->CloneAndFuseInternal(fused_root); - return instruction; + return MakeUnique(shape, fusion_kind, fused_root); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->fusion_kind_ = fusion_kind; - instruction->name_ = "fusion"; - instruction->called_computations_.push_back(fusion_computation); - fusion_computation->SetFusionInstruction(instruction.get()); - return instruction; + return MakeUnique(shape, fusion_kind, operands, + fusion_computation); } -void HloInstruction::set_device_sharding(int64 device) { - HloSharding 
device_sharding = HloSharding::AssignDevice(device); +void HloInstruction::set_single_sharding(const HloSharding& sharding) { + CHECK(!sharding.IsTuple()) << sharding; if (ShapeUtil::IsTuple(shape())) { - set_sharding(HloSharding::Tuple(device_sharding.GetAsShapeTree(shape()))); + set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape()))); } else { - set_sharding(device_sharding); + set_sharding(sharding); } } @@ -862,289 +879,6 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->set_metadata(metadata_); } -HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { - CHECK_EQ(opcode(), HloOpcode::kFusion); - CHECK_EQ(operand_count(), - fused_instructions_computation()->parameter_instructions().size()); - const int64 param_no = operand_count(); - // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. - string param_name = StrCat(new_operand->name(), ".param_", param_no); - HloInstruction* fused_parameter = - fused_instructions_computation()->AddParameter( - HloInstruction::CreateParameter(param_no, new_operand->shape(), - param_name)); - AppendOperand(new_operand); - return fused_parameter; -} - -void HloInstruction::MergeFusionInstruction( - HloInstruction* instruction_to_merge) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); - CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != - operands().end()); - // Clone the instruction from which to merge fused instructions. - std::unique_ptr clone = instruction_to_merge->Clone(); - // Replace uses of fused parameters with the corresponding operand of the - // fusion. Add all non-parameter fused instructions to 'unfused_instructions' - // to be merged into 'this'. This is done in reverse post order. 
- std::vector unfused_instructions; - auto fused_instructions = - clone->fused_instructions_computation()->MakeInstructionPostOrder(); - for (auto fused_it = fused_instructions.rbegin(); - fused_it != fused_instructions.rend(); ++fused_it) { - auto fused_instruction = *fused_it; - if (fused_instruction->opcode() == HloOpcode::kParameter) { - TF_CHECK_OK(fused_instruction->ReplaceAllUsesWith( - clone->mutable_operand(fused_instruction->parameter_number()))); - } else { - unfused_instructions.push_back(fused_instruction); - } - } - CHECK(unfused_instructions.front() == clone->fused_expression_root()); - // Replace instruction_to_merge use of 'this' with unfused_root. - TF_CHECK_OK( - instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front())); - // Fuse 'unfused_instructions' into 'this'. - for (auto& instruction : unfused_instructions) { - FuseInstruction(instruction); - instruction->DetachFromOperands(); - } - CHECK_EQ(0, clone->user_count()); - clone->DetachFromOperands(); - TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( - clone->fused_instructions_computation())); -} - -void HloInstruction::MergeFusionInstructionIntoMultiOutput( - HloInstruction* instruction_to_merge) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); - // Add all non-parameter fused instructions to 'unfused_instructions' to be - // merged into 'this'. `old_to_new' maps the instructions in the fused node - // to the disaseembled fusion instructions. - // Note that we add the unfused instructions to this->parent_ computation. - // This is necessary because the unique_id needs for an instruction and - // it's only added when inserting to the computation. 
- tensorflow::gtl::FlatMap old_to_new; - std::vector unfused_instructions; - auto computation_to_merge = - instruction_to_merge->fused_instructions_computation(); - auto post_order = computation_to_merge->MakeInstructionPostOrder(); - for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) { - auto fused_instruction = *rit; - if (fused_instruction->opcode() == HloOpcode::kParameter) { - InsertOrDie(&old_to_new, fused_instruction, - instruction_to_merge->mutable_operand( - fused_instruction->parameter_number())); - continue; - } - - // Here we clone the insertion and call FuseInstructionIntoMultiOutput() - // which clones again. This can be improved. - auto cloned_instruction = - parent_->AddInstruction(fused_instruction->Clone()); - unfused_instructions.push_back(cloned_instruction); - InsertOrDie(&old_to_new, fused_instruction, cloned_instruction); - } - for (auto unfused_instruction : unfused_instructions) { - for (int64 index = 0; index < unfused_instruction->operand_count(); - index++) { - auto new_operand = - FindOrDie(old_to_new, unfused_instruction->mutable_operand(index)); - TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand)); - } - } - - HloInstruction* unfused_root = unfused_instructions.front(); - TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root)); - - TF_CHECK_OK( - instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge)); - if (GetModule()) { - TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge)); - } - - // Fuse the root instruction and generate multiple outputs. - FuseInstructionIntoMultiOutput(unfused_root); - TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); - // The rest instructions are of normal fusing. 
- for (int64 i = 1; i < unfused_instructions.size(); i++) { - auto instruction = unfused_instructions[i]; - FuseInstruction(instruction); - TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); - } -} - -HloInstruction* HloInstruction::FuseInstructionInternal( - HloInstruction* instruction_to_fuse, bool add_output) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - - // When add_output is false, this fusion instruction must be a user of - // instruction_to_fuse. - if (!add_output) { - CHECK(IsUserOf(instruction_to_fuse)); - } - HloInstruction* fused_instruction = - CloneAndFuseInternal(instruction_to_fuse, add_output); - return fused_instruction; -} - -HloInstruction* HloInstruction::CloneAndFuseInternal( - HloInstruction* instruction_to_fuse, bool add_output) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); - VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); - HloInstruction* clone = nullptr; - if (called_computations_.empty()) { - // New fusion instruction. It should not be a multioutput instruction. - CHECK(!add_output); - auto builder = HloComputation::Builder("fused_computation", this); - builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); - called_computations_.push_back( - CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); - clone = fused_expression_root(); - } else { - clone = fused_instructions_computation()->AddInstruction( - instruction_to_fuse->Clone(/*suffix=*/"")); - // When add_output is false, instruction_to_fuse is necessarily an operand - // of the fusion instruction. After fusion this will no longer be the case. - // Remove the operand from the operand list and remove its corresponding - // fused parameter instruction. Renumber parameters as necessary to make - // parameter numbers consistent with their index in the - // fused_parameter_ vector. 
- bool in_operand_list = std::find(operands_.begin(), operands_.end(), - instruction_to_fuse) != operands_.end(); - CHECK(add_output || in_operand_list); - const std::vector& fused_parameters = - fused_instructions_computation()->parameter_instructions(); - for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { - if (instruction_to_fuse == operands_[operand_num]) { - // replace the fused parameter instruction's uses with the clone. - HloInstruction* fused_parameter = fused_parameters[operand_num]; - TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone)); - - // Remove the corresponding fused parameter and operand from their - // respective vectors. - TF_CHECK_OK( - fused_instructions_computation()->RemoveParameter(operand_num)); - operands_.erase(operands_.begin() + operand_num); - break; - } - } - // We've cloned instruction_to_fuse into this fusion instruction, so this - // fusion instruction is no longer a use of instruction_to_fuse. - if (in_operand_list) { - instruction_to_fuse->RemoveUser(this); - // When the instruction_to_fuse does not have other users, we don't need - // to generate a multioutput fusion instruction. - if (instruction_to_fuse->user_count() == 0) { - add_output = false; - } - } - } - - // Reread the parameters in the computation. - const std::vector& fused_parameters = - fused_instructions_computation()->parameter_instructions(); - - // Add each operand of the clone as an operand of the fusion instruction. A - // complication is that some clone operands may already be operands of the - // fusion instruction. - for (int64 operand_num = 0; operand_num < clone->operand_count(); - ++operand_num) { - HloInstruction* operand = clone->mutable_operand(operand_num); - - // See if this operand is already an operand of the fusion node. 
- CHECK_EQ(operands_.size(), fused_parameters.size()); - HloInstruction* fused_param = nullptr; - for (int64 i = 0; i < operands_.size(); ++i) { - if (operands_[i] == operand) { - fused_param = fused_parameters[i]; - break; - } - } - - if (fused_param == nullptr) { - // Clone's operand was not already an operand of the fusion - // instruction. Add it as an operand and add a corresponding fused - // parameter instruction. - fused_param = AddFusionOperand(operand); - } - TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); - } - - if (add_output) { - CHECK_GT(instruction_to_fuse->user_count(), 0); - // If this is already a multioutput fusion instruction, expand the root - // tuple by 1. - HloInstruction* fused_root = fused_expression_root(); - HloInstruction::InstructionVector tuple_elements; - bool newly_created_tuple_instr = false; - if (fused_root->opcode() == HloOpcode::kTuple) { - tuple_elements = fused_root->operands(); - } else { - tuple_elements.push_back(fused_root); - newly_created_tuple_instr = true; - } - if (clone->opcode() == HloOpcode::kTuple) { - for (auto inst : clone->operands()) { - tuple_elements.push_back(inst); - } - } else { - tuple_elements.push_back(clone); - } - HloInstruction* new_root = fused_instructions_computation()->AddInstruction( - HloInstruction::CreateTuple(tuple_elements)); - fused_instructions_computation()->set_root_instruction(new_root); - shape_ = new_root->shape(); - if (fused_root->opcode() == HloOpcode::kTuple) { - TF_CHECK_OK( - fused_instructions_computation()->RemoveInstruction(fused_root)); - } - - // If this is a newly created multioutput instruction, we need to update - // the use of the original fusion instruction. 
- if (newly_created_tuple_instr) { - HloInstruction* new_instr = parent_->AddInstruction( - HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); - TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); - } - int64 index = tuple_elements.size(); - if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { - index -= instruction_to_fuse->operand_count(); - std::vector to_be_removed; - for (auto old_gte : instruction_to_fuse->users()) { - CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); - int64 old_tuple_index = old_gte->tuple_index(); - HloInstruction* new_gte = - parent_->AddInstruction(HloInstruction::CreateGetTupleElement( - old_gte->shape(), this, index + old_tuple_index)); - TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte)); - to_be_removed.push_back(old_gte); - } - for (auto old_gte : to_be_removed) { - TF_CHECK_OK(parent_->RemoveInstruction(old_gte)); - } - TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); - } else { - HloInstruction* new_gte = - parent_->AddInstruction(HloInstruction::CreateGetTupleElement( - clone->shape(), this, index - 1)); - TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte)); - } - } - - VLOG(2) << "New clone:\n" << clone->ToString(); - return clone; -} - -RandomDistribution HloInstruction::random_distribution() const { - CHECK_EQ(opcode_, HloOpcode::kRng); - return distribution_; -} - bool HloInstruction::HasSideEffectNoRecurse() const { switch (opcode_) { case HloOpcode::kSend: @@ -1287,6 +1021,28 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case 
HloOpcode::kConstant: + case HloOpcode::kTrace: + case HloOpcode::kFusion: + case HloOpcode::kRng: + case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1347,10 +1103,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[2]); break; // Other supported ops. - case HloOpcode::kBroadcast: - CHECK_EQ(new_operands.size(), 1); - clone = CreateBroadcast(shape, new_operands[0], dimensions_); - break; case HloOpcode::kCall: clone = CreateCall(shape, new_operands, to_apply()); break; @@ -1369,9 +1121,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateHostCompute(shape, new_operands, channel_name_, cost_estimate_ns_); break; - case HloOpcode::kConcatenate: - clone = CreateConcatenate(shape, new_operands, dimensions(0)); - break; case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); clone = CreateConvert(shape, new_operands[0]); @@ -1380,11 +1129,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kReducePrecision: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_, - mantissa_bits_); - break; case HloOpcode::kConvolution: CHECK_EQ(new_operands.size(), 2); clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, @@ -1395,30 +1139,11 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateDot(shape, new_operands[0], new_operands[1], *dot_dimension_numbers_); break; - case HloOpcode::kFft: - CHECK_EQ(new_operands.size(), 1); - clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); - break; - case HloOpcode::kCrossReplicaSum: - clone = CreateCrossReplicaSum(shape, new_operands, to_apply()); - 
break; - case HloOpcode::kGetTupleElement: - CHECK_EQ(new_operands.size(), 1); - clone = CreateGetTupleElement(shape, new_operands[0], tuple_index()); - break; - case HloOpcode::kMap: - clone = CreateMap(shape, new_operands, to_apply()); - break; case HloOpcode::kPad: CHECK_EQ(new_operands.size(), 2); clone = CreatePad(shape, new_operands[0], new_operands[1], *padding_config_); break; - case HloOpcode::kReduce: - CHECK_EQ(new_operands.size(), 2); - clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, - to_apply()); - break; case HloOpcode::kReduceWindow: CHECK_EQ(new_operands.size(), 2); clone = CreateReduceWindow(shape, new_operands[0], new_operands[1], @@ -1430,22 +1155,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CreateSelectAndScatter(shape, new_operands[0], select(), *window_, new_operands[1], new_operands[2], scatter()); break; - case HloOpcode::kReverse: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReverse(shape, new_operands[0], dimensions_); - break; - case HloOpcode::kRng: - clone = CreateRng(shape, distribution_, new_operands); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); break; - case HloOpcode::kSlice: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_, - slice_strides_); - break; case HloOpcode::kDynamicSlice: clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); @@ -1455,10 +1168,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], new_operands[2]); break; - case HloOpcode::kTranspose: - CHECK_EQ(new_operands.size(), 1); - clone = CreateTranspose(shape, new_operands[0], dimensions_); - break; case HloOpcode::kTuple: clone = CreateTuple(new_operands); *clone->mutable_shape() = shape; @@ -1468,60 +1177,12 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( 
clone = CreateWhile(shape, while_condition(), while_body(), new_operands[0]); break; - case HloOpcode::kConstant: - clone = CreateConstant(literal_->CloneToUnique()); - break; - case HloOpcode::kFusion: { - HloModule* module = context != nullptr ? context->module() : GetModule(); - HloComputation* new_fused_computation = nullptr; - if (context != nullptr) { - new_fused_computation = - context->FindComputation(fused_instructions_computation()); - } - if (new_fused_computation == nullptr) { - new_fused_computation = module->AddEmbeddedComputation( - fused_instructions_computation()->Clone("clone", context)); - } - clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), - /*operands=*/new_operands, - /*fusion_computation=*/new_fused_computation); - break; - } - case HloOpcode::kParameter: - clone = CreateParameter(parameter_number_, shape, name_); - break; - case HloOpcode::kInfeed: - CHECK_EQ(new_operands.size(), 0); - clone = CreateInfeed(shape, infeed_config()); - break; - case HloOpcode::kOutfeed: - CHECK_EQ(new_operands.size(), 1); - clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); - break; case HloOpcode::kConditional: CHECK_EQ(new_operands.size(), 3); clone = CreateConditional(shape, new_operands[0], new_operands[1], true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kSend: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSend(new_operands[0], channel_id()); - break; - case HloOpcode::kSendDone: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSendDone(new_operands[0]); - break; - case HloOpcode::kRecv: - CHECK_EQ(new_operands.size(), 0); - // The shape is a tuple, but CreateRecv() wants the raw data shape. 
- clone = - CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); - break; - case HloOpcode::kRecvDone: - CHECK_EQ(new_operands.size(), 1); - clone = CreateRecvDone(new_operands[0]); - break; case HloOpcode::kGather: CHECK_EQ(new_operands.size(), 2); clone = CreateGather(shape, new_operands[0], new_operands[1], @@ -1533,8 +1194,9 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), user_side_metadata_->Clone()); break; - case HloOpcode::kTrace: - LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); + case HloOpcode::kGenerateToken: + clone = CreateGenerateToken(new_operands); + break; } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); @@ -1550,7 +1212,29 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( return clone; } -HloInstruction::~HloInstruction() {} +HloInstruction::~HloInstruction() { + // Detach from operands. An instruction may be repeated as an operand. To + // avoid calling RemoveUser twice on the same operand, check before remove. + for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { + HloInstruction* operand = operands_[operand_num]; + if (operand == nullptr) { + continue; + } + if (operand->user_set_.find(this) != operand->user_set_.end()) { + operand->RemoveUser(this); + } + operands_[operand_num] = nullptr; + } + + // Update users. Set `nullptr` to the correpsonding operand slot for users. 
+ for (auto& user : this->users()) { + for (int i = 0; i < user->operand_count(); ++i) { + if (user->operands_[i] == this) { + user->operands_[i] = nullptr; + } + } + } +} std::unique_ptr HloInstruction::Clone( const string& suffix, HloCloneContext* context) const { @@ -1615,40 +1299,6 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const { return hlo; } -const Literal& HloInstruction::literal() const { - CHECK_EQ(HloOpcode::kConstant, opcode_); - return *literal_; -} - -bool HloInstruction::HasLiteral() const { return literal_ != nullptr; } - -bool HloInstruction::CanHaveDimensionsField() const { - return (opcode() == HloOpcode::kReverse || - opcode() == HloOpcode::kConcatenate || - opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast || - opcode() == HloOpcode::kTranspose); -} - -const std::vector& HloInstruction::dimensions() const { - CHECK(CanHaveDimensionsField()); - return dimensions_; -} - -int64 HloInstruction::dimensions(int64 index) const { - return dimensions()[index]; -} - -int64 HloInstruction::concatenate_dimension() const { - CHECK(opcode() == HloOpcode::kConcatenate); - CHECK_EQ(1, dimensions_.size()); - return dimensions(0); -} - -int64 HloInstruction::tuple_index() const { - CHECK_EQ(HloOpcode::kGetTupleElement, opcode_); - return tuple_index_; -} - const HloInstruction* HloInstruction::operand(int64 i) const { return operands_[i]; } @@ -1737,10 +1387,6 @@ void HloInstruction::AddUser(HloInstruction* user) { } } -bool HloInstruction::IsConstant() const { - return opcode_ == HloOpcode::kConstant; -} - bool HloInstruction::HasConstantOperand() const { for (const HloInstruction* operand : operands_) { if (operand->IsConstant()) { @@ -1809,36 +1455,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kTuple: return true; - // Broadcast, Concatenate, and Transpose need the same dimensions field. 
- case HloOpcode::kBroadcast: - case HloOpcode::kConcatenate: - case HloOpcode::kTranspose: - return dimensions() == other.dimensions(); - - case HloOpcode::kFusion: - return fusion_kind() == other.fusion_kind() && - eq_computations(fused_instructions_computation(), - other.fused_instructions_computation()); - // These opcodes have complex or special behavior so just return false. case HloOpcode::kDomain: - case HloOpcode::kRng: - case HloOpcode::kTrace: case HloOpcode::kWhile: + case HloOpcode::kGenerateToken: return false; - case HloOpcode::kParameter: - return parameter_number() == other.parameter_number(); - - // A constant is defined by the value in the literal. - case HloOpcode::kConstant: - return literal() == other.literal(); - - // A reduce-precision operation is determined by the bit sizes. - case HloOpcode::kReducePrecision: - return exponent_bits() == other.exponent_bits() && - mantissa_bits() == other.mantissa_bits(); - // Convolution has a window and dimensions. case HloOpcode::kConvolution: return protobuf_util::ProtobufEquals(window(), other.window()) && @@ -1855,16 +1477,6 @@ bool HloInstruction::IdenticalSlowPath( other.gather_dimension_numbers()) && gather_window_bounds() == other.gather_window_bounds(); - // FFT has various types & lengths. - case HloOpcode::kFft: - return fft_type() == other.fft_type() && - fft_length() == other.fft_length(); - - // Reduction results are determined by the reduction dimension and the - // reduction computation. - case HloOpcode::kReduce: - return dimensions() == other.dimensions() && - eq_computations(to_apply(), other.to_apply()); case HloOpcode::kReduceWindow: return eq_computations(to_apply(), other.to_apply()) && protobuf_util::ProtobufEquals(window(), other.window()); @@ -1877,19 +1489,14 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals(window(), other.window()); // Remaining instructions with special values. 
- case HloOpcode::kGetTupleElement: - return tuple_index() == other.tuple_index(); case HloOpcode::kPad: return protobuf_util::ProtobufEquals(padding_config(), other.padding_config()); - case HloOpcode::kSlice: - return slice_starts_ == other.slice_starts_ && - slice_limits_ == other.slice_limits_ && - slice_strides_ == other.slice_strides_; case HloOpcode::kCall: case HloOpcode::kCrossReplicaSum: - case HloOpcode::kMap: - return eq_computations(to_apply(), other.to_apply()); + return replica_group_ids() == other.replica_group_ids() && + cross_replica_sum_barrier() == other.cross_replica_sum_barrier() && + eq_computations(to_apply(), other.to_apply()); case HloOpcode::kCustomCall: if ((window_ == nullptr) != (other.window_ == nullptr) || (window_ != nullptr && @@ -1905,20 +1512,12 @@ bool HloInstruction::IdenticalSlowPath( return false; } return custom_call_target_ == other.custom_call_target_; - case HloOpcode::kReverse: - return dimensions() == other.dimensions(); case HloOpcode::kConditional: return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); // These opcodes are not yet supported. 
- case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: case HloOpcode::kSort: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: case HloOpcode::kHostCompute: return false; @@ -1927,19 +1526,32 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case HloOpcode::kConstant: + case HloOpcode::kTrace: + case HloOpcode::kFusion: + case HloOpcode::kRng: + case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } } -bool HloInstruction::IsRank2Transpose() const { - return (opcode_ == HloOpcode::kTranspose) && - dimensions_ == std::vector({1, 0}) && - shape_.dimensions_size() == 2 && - std::equal(shape_.dimensions().begin(), shape_.dimensions().end(), - operands_[0]->shape_.dimensions().rbegin()); -} - void HloInstruction::RemoveUser(HloInstruction* user) { auto set_it = user_set_.find(user); CHECK(set_it != user_set_.end()); @@ -2021,22 +1633,6 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { return Status::OK(); } -void HloInstruction::DetachFromOperands() { - VLOG(3) << "DetachFromOperands:\n " << ToString(); - CHECK_EQ(0, user_count()); - // An instruction may be repeated as an operand. To avoid calling RemoveUser - // twice on the same operand, keep a set of already detached operands. 
- std::set detached_operands; - for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { - HloInstruction* operand = operands_[operand_num]; - if (!ContainsKey(detached_operands, operand)) { - operand->RemoveUser(this); - detached_operands.insert(operand); - } - operands_[operand_num] = nullptr; - } -} - HloComputation* HloInstruction::to_apply() const { switch (opcode_) { case HloOpcode::kCall: @@ -2061,6 +1657,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; break; @@ -2075,11 +1672,6 @@ const string& HloInstruction::custom_call_target() const { return custom_call_target_; } -const string& HloInstruction::outfeed_config() const { - CHECK_EQ(opcode_, HloOpcode::kOutfeed); - return outfeed_config_; -} - HloComputation* HloInstruction::while_condition() const { CHECK_EQ(HloOpcode::kWhile, opcode_); return called_computations_[kConditionComputationIndex]; @@ -2179,6 +1771,71 @@ string HloInstruction::ToString(const HloPrintOptions& options) const { return ToStringWithCanonicalNameMap(options, &new_map); } +bool HloInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + switch (opcode_) { + // Unary elementwise operations. 
+ case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kCeil: + case HloOpcode::kClz: + case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kTanh: + CHECK_EQ(1, operand_count()); + return true; + + // Binary elementwise operations, the same as in IsElementwiseBinary(). + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: + case HloOpcode::kDivide: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + CHECK_EQ(2, operand_count()); + return true; + + // Ternary elementwise operations. + case HloOpcode::kSelect: + return !ShapeUtil::IsTuple(shape_); + case HloOpcode::kClamp: + return true; + + default: + return false; + } +} + string HloInstruction::ToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { @@ -2229,76 +1886,44 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { string operands; - if (opcode() == HloOpcode::kConstant) { - // For constants, show the actual value in place of an empty operand list. 
- // - // In HloInstruction, sometimes a constant literal is not constructed due - // to its size. Skip the printing in this case. - if (HasLiteral() && ((!ShapeUtil::IsTuple(shape()) && - ShapeUtil::ElementsIn(shape()) <= 10) || - options.print_large_constants())) { - // Literal::ToString emits multidimensional arrays over multiple - // lines. Compact this into one line by stripping out white space. - string tmp = literal().ToString(); - std::replace(tmp.begin(), tmp.end(), '\n', ' '); - std::vector v = tensorflow::str_util::Split(tmp, ' '); - bool first = true; - // Concatenate elements in "v" with spaces separating them, but ignoring - // empty entries. - for (const auto& s : v) { - if (s.empty()) { - continue; - } - StrAppend(&operands, (first ? "" : " "), s); - first = false; - } - } else { - // Do not show large constants or tuples. - operands = "{...}"; + tensorflow::gtl::ArraySlice slice(operands_); + const int64 kMaxOperandsToShowIfCompact = 4; + if (options.compact_operands() && + slice.size() > kMaxOperandsToShowIfCompact) { + slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); + } + operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + // If operand is already been deleted, put `null` to the string output. 
+ if (operand == nullptr) { + StrAppend(out, "null "); + return; } - } else if (opcode() == HloOpcode::kParameter) { - StrAppend(&operands, parameter_number_); - } else { - tensorflow::gtl::ArraySlice slice(operands_); - const int64 kMaxOperandsToShowIfCompact = 4; - if (options.compact_operands() && - slice.size() > kMaxOperandsToShowIfCompact) { - slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); + std::vector str; + if (options.print_operand_shape()) { + str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); } - operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { - std::vector str; - if (options.print_operand_shape()) { - str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); - } - // In a top-level HloInstruction::ToString() call, the operand name is not - // part of the canonical string. - if (options.canonicalize_instruction_names() && - options.is_in_nested_computation()) { - str.push_back(PrintName( - canonical_name_map->LookupOrInsert(operand->name()), options)); - } else if (!options.compact_operands()) { - str.push_back(PrintName(operand->name(), options)); - } - StrAppend(out, Join(str, " ")); - }); - const int64 remaining = operands_.size() - slice.size(); - if (slice.size() != operands_.size()) { - StrAppend(&operands, ", ...(+", remaining, ")"); + // In a top-level HloInstruction::ToString() call, the operand name is not + // part of the canonical string. 
+ if (options.canonicalize_instruction_names() && + options.is_in_nested_computation()) { + str.push_back(PrintName( + canonical_name_map->LookupOrInsert(operand->name()), options)); + } else if (!options.compact_operands()) { + str.push_back(PrintName(operand->name(), options)); } + StrAppend(out, Join(str, " ")); + }); + const int64 remaining = operands_.size() - slice.size(); + if (slice.size() != operands_.size()) { + StrAppend(&operands, ", ...(+", remaining, ")"); } return operands; } std::vector HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { - std::vector extra; - if (opcode() == HloOpcode::kFusion) { - extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); - } - if (CanHaveDimensionsField()) { - extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); - } + std::vector extra = ExtraAttributesToStringImpl(options); if (window_ != nullptr && window_->dimensions_size() != 0) { extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); } @@ -2306,19 +1931,7 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); } - if (opcode() == HloOpcode::kSlice) { - std::vector bounds; - bounds.reserve(slice_starts_.size()); - const bool omit_stride = - std::all_of(slice_strides_.begin(), slice_strides_.end(), - [](int64 stride) { return stride == 1; }); - for (int i = 0; i < slice_starts_.size(); ++i) { - string stride_str = omit_stride ? 
"" : StrCat(":", slice_strides_[i]); - bounds.push_back(StrCat("[", slice_starts_[i], ":", slice_limits_[i], - stride_str, "]")); - } - extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); - } + if (opcode() == HloOpcode::kDynamicSlice) { extra.push_back( StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); @@ -2337,10 +1950,6 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); } - if (opcode() == HloOpcode::kFft) { - extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); - extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); - } if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { @@ -2396,6 +2005,7 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); break; @@ -2411,14 +2021,7 @@ std::vector HloInstruction::ExtraAttributesToString( break; } } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || - opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { - extra.push_back(StrCat("channel_id=", channel_id_)); - } - if (opcode() == HloOpcode::kGetTupleElement) { - extra.push_back(StrCat("index=", tuple_index())); - } if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } @@ -2431,26 +2034,12 @@ std::vector HloInstruction::ExtraAttributesToString( }), "}")); } - if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) { - extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")); - } - if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) { - extra.push_back( - StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); - } - if (opcode() == HloOpcode::kRng) { - extra.push_back( - 
StrCat("distribution=", RandomDistributionToString(distribution_))); - } - if (opcode() == HloOpcode::kReducePrecision) { - extra.push_back(StrCat("exponent_bits=", exponent_bits_)); - extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); - } if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), "\", entry=", operand_side_metadata_->ToString(), ", exit=", user_side_metadata_->ToString(), "}")); } + // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. @@ -2489,24 +2078,12 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; proto.set_backend_config(backend_config_); - if (literal_ != nullptr) { - *proto.mutable_literal() = literal_->ToProto(); - } - proto.set_parameter_number(parameter_number_); - if (opcode() == HloOpcode::kFusion) { - proto.set_fusion_kind(xla::ToString(fusion_kind())); - proto.add_called_computation_ids( - fused_instructions_computation()->unique_id()); - } else { + if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { proto.add_called_computation_ids(computation->unique_id()); } } - proto.set_tuple_index(tuple_index_); - for (int64 dimension : dimensions_) { - proto.add_dimensions(dimension); - } if (window_ != nullptr) { *proto.mutable_window() = *window_; } @@ -2525,32 +2102,14 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_gather_window_bounds(bound); } } - for (int i = 0; i < slice_starts_.size(); ++i) { - auto* slice_dimension = proto.add_slice_dimensions(); - slice_dimension->set_start(slice_starts_[i]); - slice_dimension->set_limit(slice_limits_[i]); - slice_dimension->set_stride(slice_strides_[i]); - } - proto.set_exponent_bits(exponent_bits_); - proto.set_mantissa_bits(mantissa_bits_); + for (int64 slice_size : 
dynamic_slice_sizes_) { proto.add_dynamic_slice_sizes(slice_size); } if (padding_config_ != nullptr) { *proto.mutable_padding_config() = *padding_config_; } - proto.set_outfeed_config(outfeed_config_); - if (opcode() == HloOpcode::kRng) { - proto.set_distribution(distribution_); - } - proto.set_channel_id(channel_id_); - proto.set_infeed_config(infeed_config_); proto.set_custom_call_target(custom_call_target_); - *proto.mutable_outfeed_shape() = outfeed_shape_; - proto.set_fft_type(fft_type_); - for (int64 fft_len : fft_length_) { - proto.add_fft_length(fft_len); - } if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); @@ -2610,12 +2169,6 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { trace_instruction_ = trace_instruction; } -string HloInstruction::TracingTag() const { - CHECK_EQ(HloOpcode::kTrace, opcode()); - CHECK(literal_ != nullptr); - return literal_->GetR1U8AsString(); -} - bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } bool HloInstruction::IsFusable() const { @@ -2634,51 +2187,6 @@ bool HloInstruction::IsFusable() const { } } -HloComputation* HloInstruction::fused_instructions_computation() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(!called_computations_.empty()); - auto* fused_instructions_computation = called_computations_.front(); - CHECK(fused_instructions_computation->IsFusionComputation()) - << "Computation " << fused_instructions_computation->name() - << " is not a fusion kind"; - return fused_instructions_computation; -} - -HloInstruction* HloInstruction::fused_expression_root() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->root_instruction(); -} - -HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->parameter_instruction( - parameter_number); -} - -const std::vector& HloInstruction::fused_parameters() 
const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->parameter_instructions(); -} - -const tensorflow::gtl::iterator_range>::const_iterator>> -HloInstruction::fused_instructions() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - const HloComputation* subcomp = fused_instructions_computation(); - return subcomp->instructions(); -} - -const tensorflow::gtl::iterator_range< - UnwrappingIterator>::iterator>> -HloInstruction::fused_instructions() { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->instructions(); -} - -int64 HloInstruction::fused_instruction_count() const { - return fused_instructions_computation()->instruction_count(); -} - HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) : unique_id_(-1), opcode_(opcode), @@ -2857,6 +2365,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleGather(this); case HloOpcode::kDomain: return visitor->HandleDomain(this); + case HloOpcode::kGenerateToken: + return visitor->HandleGenerateToken(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -3097,12 +2607,6 @@ Status HloInstruction::AcceptOrdered( return visitor->FinishVisit(this); } -const Shape& HloInstruction::outfeed_shape() const { - DCHECK_EQ(opcode_, HloOpcode::kOutfeed); - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); - return outfeed_shape_; -} - const Shape& HloInstruction::shape() const { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); return shape_; @@ -3124,87 +2628,7 @@ bool HloInstruction::IsElementwiseBinary() const { } bool HloInstruction::IsElementwise() const { - switch (opcode_) { - // Nullary elementwise operations. - case HloOpcode::kConstant: - return true; - - // Unary elementwise operations. 
- case HloOpcode::kAbs: - case HloOpcode::kRoundNearestAfz: - case HloOpcode::kCeil: - case HloOpcode::kClz: - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCopy: - case HloOpcode::kCos: - case HloOpcode::kExp: - case HloOpcode::kExpm1: - case HloOpcode::kFloor: - case HloOpcode::kImag: - case HloOpcode::kIsFinite: - case HloOpcode::kLog: - case HloOpcode::kLog1p: - case HloOpcode::kNot: - case HloOpcode::kNegate: - case HloOpcode::kReal: - case HloOpcode::kReducePrecision: - case HloOpcode::kSign: - case HloOpcode::kSin: - case HloOpcode::kTanh: - CHECK_EQ(1, operand_count()); - return true; - - // Binary elementwise operations, the same as in IsElementwiseBinary(). - case HloOpcode::kAdd: - case HloOpcode::kAtan2: - case HloOpcode::kComplex: - case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kNe: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: - case HloOpcode::kShiftLeft: - case HloOpcode::kShiftRightArithmetic: - case HloOpcode::kShiftRightLogical: - CHECK_EQ(2, operand_count()); - return true; - - // Ternary elementwise operations. - case HloOpcode::kSelect: - return !ShapeUtil::IsTuple(shape_); - case HloOpcode::kClamp: - return true; - - // Other operations. 
- case HloOpcode::kRng: - case HloOpcode::kMap: - return true; - case HloOpcode::kFusion: - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - for (auto* fused : fused_instructions()) { - if (fused->opcode() != HloOpcode::kParameter && - !fused->IsElementwise()) { - return false; - } - } - return true; - - default: - return false; - } + return IsElementwiseImpl(tensorflow::gtl::nullopt); } bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { @@ -3212,54 +2636,8 @@ bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape()); } -namespace { -bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, - const HloInstruction* operand) { - std::vector operand_indices = instruction->OperandIndices(operand); - return std::all_of( - operand_indices.begin(), operand_indices.end(), - [instruction](int64 operand_index) { - return instruction->IsElementwiseOnOperand(operand_index); - }); -} -} // namespace - bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { - // For all instructions other than kFusion, being elementwise on one of the - // operands is equivalent to being elementwise on all the operands. - if (opcode() != HloOpcode::kFusion) { - return IsElementwise(); - } - - CHECK_EQ(HloOpcode::kFusion, opcode()); - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - - // A loop-fusion is elementwise on an operand if all operations (computed - // using BFS) between the operand and the fused root are elementwise. 
- std::deque worklist; - std::unordered_set visited; - worklist.push_back(fused_parameter(operand_idx)); - visited.insert(fused_parameter(operand_idx)); - while (!worklist.empty()) { - HloInstruction* operand = worklist.front(); - worklist.pop_front(); - for (HloInstruction* user : operand->users()) { - CHECK_GE(user->unique_id(), 0); - if (ContainsKey(visited, user)) { - continue; - } - if (user->IsElementwise() || - IsInstructionElementwiseOnOperand(user, operand)) { - worklist.push_back(user); - visited.insert(user); - } else { - return false; - } - } - } - return true; + return IsElementwiseImpl(operand_idx); } // A helper class for memoized, recursive computation of HloOpcode::kFusion @@ -3281,8 +2659,10 @@ class HloInstruction::FusionReusesParamElements { static UseKind ComputeInternal( int64 i, const HloInstruction& hlo, tensorflow::gtl::FlatMap* cache) { - if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) { - return UseKind::kUse; + if (auto hlo_param = DynCast(&hlo)) { + if (hlo_param->parameter_number() == i) { + return UseKind::kUse; + } } auto p = cache->emplace(&hlo, UseKind{}); @@ -3591,30 +2971,207 @@ void HloInstruction::set_outer_dimension_partitions( outer_dimension_partitions_ = outer_dimension_partitions; } +// TODO(b/80131774): Remove these temporary methods after transition. 
+int64 HloInstruction::feature_index() const { + return Cast(this)->feature_index(); +} + +float HloInstruction::epsilon() const { + return Cast(this)->epsilon(); +} + +FftType HloInstruction::fft_type() const { + return Cast(this)->fft_type(); +} + +const std::vector& HloInstruction::fft_length() const { + return Cast(this)->fft_length(); +} + +int64 HloInstruction::channel_id() const { + return Cast(this)->channel_id(); +} + +int64 HloInstruction::concatenate_dimension() const { + return Cast(this)->concatenate_dimension(); +} + +bool HloInstruction::IsRank2Transpose() const { + auto transpose = DynCast(this); + return transpose != nullptr && transpose->IsRank2Transpose(); +} + +int64 HloInstruction::slice_starts(int64 dimension) const { + return Cast(this)->slice_starts(dimension); +} + +const std::vector& HloInstruction::slice_starts() const { + return Cast(this)->slice_starts(); +} + +int64 HloInstruction::slice_limits(int64 dimension) const { + return Cast(this)->slice_limits(dimension); +} + +const std::vector& HloInstruction::slice_limits() const { + return Cast(this)->slice_limits(); +} + +int64 HloInstruction::slice_strides(int64 dimension) const { + return Cast(this)->slice_strides(dimension); +} + +const std::vector& HloInstruction::slice_strides() const { + return Cast(this)->slice_strides(); +} + +bool HloInstruction::IsInPlaceSlice() const { + return Cast(this)->IsInPlaceSlice(); +} + +const Literal& HloInstruction::literal() const { + return Cast(this)->literal(); +} + +bool HloInstruction::IsConstant() const { + return DynCast(this) != nullptr; +} + void HloInstruction::RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index) { - CHECK_EQ(opcode(), HloOpcode::kConstant); - Shape* mutable_array_subshape = - ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); - CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + Cast(this)->RelayoutConstant(new_layout, shape_index); +} - // Normally array_subshape will always have a 
layout, but this invariant is - // temporarily broken in LayoutAssignment::AssignLayouts. +string HloInstruction::TracingTag() const { + return Cast(this)->TracingTag(); +} - if (!mutable_array_subshape->has_layout() || - !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); - *mutable_array_subshape->mutable_layout() = new_layout; - } +HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { + return Cast(this)->AddFusionOperand(new_operand); } -// TODO(b/80131774): Remove these temporary methods after transition. -int64 HloInstruction::feature_index() const { - return Cast(this)->feature_index(); +// Delegates to HloFusionInstruction::MergeFusionInstruction. +void HloInstruction::MergeFusionInstruction( + HloInstruction* instruction_to_merge) { + return Cast(this)->MergeFusionInstruction( + Cast(instruction_to_merge)); } -float HloInstruction::epsilon() const { - return Cast(this)->epsilon(); +// Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput. 
+void HloInstruction::MergeFusionInstructionIntoMultiOutput( + HloInstruction* instruction_to_merge) { + return Cast(this) + ->MergeFusionInstructionIntoMultiOutput( + Cast(instruction_to_merge)); +} + +HloInstruction* HloInstruction::FuseInstruction( + HloInstruction* instruction_to_fuse) { + return Cast(this)->FuseInstruction(instruction_to_fuse); +} + +HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse) { + return Cast(this)->FuseInstructionIntoMultiOutput( + instruction_to_fuse); +} + +HloComputation* HloInstruction::fused_instructions_computation() const { + return Cast(this)->fused_instructions_computation(); +} + +HloInstruction* HloInstruction::fused_expression_root() const { + return Cast(this)->fused_expression_root(); +} + +const tensorflow::gtl::iterator_range>::const_iterator>> +HloInstruction::fused_instructions() const { + return Cast(this)->fused_instructions(); +} + +const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> +HloInstruction::fused_instructions() { + return Cast(this)->fused_instructions(); +} + +int64 HloInstruction::fused_instruction_count() const { + return Cast(this)->fused_instruction_count(); +} + +HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { + return Cast(this)->fused_parameter(parameter_number); +} + +const std::vector& HloInstruction::fused_parameters() const { + return Cast(this)->fused_parameters(); +} + +const bool HloInstruction::IsMultiOutputFusion() const { + const HloFusionInstruction* fusion = DynCast(this); + return fusion != nullptr && fusion->IsMultiOutputFusion(); +} + +HloInstruction::FusionKind HloInstruction::fusion_kind() const { + return Cast(this)->fusion_kind(); +} + +void HloInstruction::set_fusion_kind(FusionKind kind) { + return Cast(this)->set_fusion_kind(kind); +} + +RandomDistribution HloInstruction::random_distribution() const { + return Cast(this)->random_distribution(); +} + +int64 
HloInstruction::parameter_number() const { + return Cast(this)->parameter_number(); +} + +int64 HloInstruction::tuple_index() const { + return Cast(this)->tuple_index(); +} + +int32 HloInstruction::exponent_bits() const { + return Cast(this)->exponent_bits(); +} + +int32 HloInstruction::mantissa_bits() const { + return Cast(this)->mantissa_bits(); +} + +string HloInstruction::infeed_config() const { + return Cast(this)->infeed_config(); +} + +void HloInstruction::set_infeed_config(const string& config) { + return Cast(this)->set_infeed_config(config); +} + +const Shape& HloInstruction::outfeed_shape() const { + return Cast(this)->outfeed_shape(); +} + +const string& HloInstruction::outfeed_config() const { + return Cast(this)->outfeed_config(); +} + +const std::vector& HloInstruction::replica_group_ids() const { + return Cast(this)->replica_group_ids(); +} + +string HloInstruction::cross_replica_sum_barrier() const { + return Cast(this)->cross_replica_sum_barrier(); +} + +void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { + return Cast(this)->set_cross_replica_sum_barrier( + barrier); +} + +tensorflow::gtl::optional HloInstruction::all_reduce_id() const { + return Cast(this)->all_reduce_id(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index b16837eaec7fda36827095b01c15cb4f84f81333..8a0ffc21cd49270316619022a243bf8e16ed1d98 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -435,16 +435,17 @@ class HloInstruction { // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // - // `channel_id`: for Allreduce nodes from different models, if they have the - // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be - // applied cross models. 
+ // `all_reduce_id`: for Allreduce nodes from different modules, if they have + // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will + // not be applied cross modules. // // TODO(b/79737069): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice replica_group_ids = {}, - const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id = tensorflow::gtl::nullopt); // Creates a conversion instruction, where operand is the data to convert and @@ -664,6 +665,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates a token instruction used for joining or creating token types which + // thread through side-effecting operations. + static std::unique_ptr CreateGenerateToken( + tensorflow::gtl::ArraySlice operands); + // Creates an instance of GatherDimensionNumbers. static GatherDimensionNumbers MakeGatherDimNumbers( tensorflow::gtl::ArraySlice output_window_dims, @@ -802,9 +808,6 @@ class HloInstruction { // Returns whether the instruction has a constant operand. bool HasConstantOperand() const; - // Returns whether this instruction does a rank-2 transposition. - bool IsRank2Transpose() const; - // Replaces the use of this instruction in "user" with "new_producer". Note // that there might be multiple uses of this instruction in "user"; all will // be replaced. @@ -821,13 +824,6 @@ class HloInstruction { // root to new_producer. Status ReplaceAllUsesWith(HloInstruction* new_producer); - // Detaches an instruction from its operands. That is, remove the instruction - // from each operand's user set. This should only be called prior to - // deallocating the instruction. 
- // - // TODO(b/78305363): Make this automatic when deleting an instruction. - void DetachFromOperands(); - // Performs a postorder DFS visit using this node as the root. If // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when // complete. If ignore_control_predecessors is true, instructions only @@ -873,38 +869,6 @@ class HloInstruction { template Status Visit(DfsHloVisitorBase* visitor); - // Returns the literal associated with this instruction. - // - // Note: only constant and parameter opcodes have an associated literal. - const Literal& literal() const; - - // Returns whether there is literal associated with this instruction. - bool HasLiteral() const; - - // Returns the parameter number associated with this instruction. - // - // Note: only parameter opcodes have an associated parameter number. - int64 parameter_number() const { - CHECK_EQ(HloOpcode::kParameter, opcode_); - return parameter_number_; - } - - // Returns the dimension sizes or numbers associated with this instruction. - // - // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, - // and reverse. - const std::vector& dimensions() const; - int64 dimensions(int64 index) const; - - // Accessor for the dimension in which a concatenate HLO should occur. - // Precondition: opcode() == HloOpcode::kConcatenate - int64 concatenate_dimension() const; - - // Returns the tuple index associated with this instruction. - // - // Precondition: opcode() == HloOpcode::kGetTupleElement - int64 tuple_index() const; - // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. @@ -936,14 +900,6 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kCustomCall const string& custom_call_target() const; - // Returns the config for the Outfeed instruction. 
- // Precondition: opcode() == HloOpcode::kOutfeed - const string& outfeed_config() const; - - // Returns the shape for the Outfeed instruction. - // Precondition: opcode() == HloOpcode::kOutfeed - const Shape& outfeed_shape() const; - // Gets/sets the while_condition or while_body HloComputation for While. The // setters should only be called by HloModule or HloComputation methods. // @@ -992,7 +948,7 @@ class HloInstruction { string OperandsToString(const HloPrintOptions& options) const; // Returns string representation of op-specific attributes. - virtual std::vector ExtraAttributesToString( + std::vector ExtraAttributesToString( const HloPrintOptions& options) const; // As ToString, but returns a shorter string. @@ -1011,105 +967,20 @@ class HloInstruction { HloInstruction* tracing() const; void set_tracing(HloInstruction* trace_instruction); - // Returns the channel id associated with the instruction. The id is - // shared between each Send/Recv pair and is globally unique to identify each - // channel. - // - // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv - int64 channel_id() const { return channel_id_; } - // Returns the channel name associated with the instruction. The name is // used to identify host Send/Recv operations. // // Precondition: opcode() == HloOpcode::kHostCompute string channel_name() const { return channel_name_; } - // Delegates to HloBatchNormInstruction::feature_index. - // TODO(b/80131774): Remove this code. - int64 feature_index() const; - - // Delegates to HloBatchNormInstruction::epsilon. - // TODO(b/80131774): Remove this code. - float epsilon() const; - - // Returns the infeed configuration string. The infeed configuration includes - // any metadata needed for the backend compiler (e.g., infeed buffer address) - // and is target-dependent. - string infeed_config() const { return infeed_config_; } - void set_infeed_config(const string& config) { infeed_config_ = config; } - - // Returns a tag to be used in tracing. 
- // - // Precondition: opcode() == HloOpcode::kTrace - string TracingTag() const; - - // Returns whether the instruction is a constant. - bool IsConstant() const; - // Returns true if this instruction is fused, ie contained within a fusion // instruction. bool IsFused() const; - // Returns the computation for this fused instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloComputation* fused_instructions_computation() const; - // Returns true if this instruction can be legally fused into a fusion // instruction. bool IsFusable() const; - // Returns the root instruction of the fused expression contained within this - // fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* fused_expression_root() const; - - // Returns the list of fused instructions inside this fusion instruction. The - // returned type is a range of HloInstruction*s. - // - // Precondition: opcode() == HloOpcode::kFusion - const tensorflow::gtl::iterator_range>::const_iterator>> - fused_instructions() const; - - const tensorflow::gtl::iterator_range< - UnwrappingIterator>::iterator>> - fused_instructions(); - - // Gets the number of instructions inside this fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - int64 fused_instruction_count() const; - - // Returns the fused parameter instruction in this fusion instruction - // corresponding to the given parameter number. - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* fused_parameter(int64 parameter_number) const; - - // Returns the vector of fused parameters inside this fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - const std::vector& fused_parameters() const; - - // Returns true if this instruction is a fusion instruction that generates - // multiple outputs. 
- const bool IsMultiOutputFusion() const { - return opcode() == HloOpcode::kFusion && - fused_expression_root()->opcode() == HloOpcode::kTuple; - } - - FusionKind fusion_kind() const { - CHECK_EQ(HloOpcode::kFusion, opcode_); - return fusion_kind_; - } - - void set_fusion_kind(FusionKind kind) { - CHECK_EQ(HloOpcode::kFusion, opcode_); - fusion_kind_ = kind; - } - // Returns the sharding applied to this operator. // REQUIRES: has_sharding() is true. const HloSharding& sharding() const { @@ -1134,8 +1005,11 @@ class HloInstruction { void set_sharding(const HloSharding& sharding) { sharding_ = MakeUnique(sharding); } + void set_single_sharding(const HloSharding& sharding); // Sets a sharding that assigns the current instruction to device. - void set_device_sharding(int64 device); + void set_device_sharding(int64 device) { + set_single_sharding(HloSharding::AssignDevice(device)); + } // Remove any sharding from this operator. void clear_sharding() { sharding_ = nullptr; } // Return true if this operator has a sharding assigned. @@ -1165,91 +1039,17 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // Adds a new operand the fusion instruction. - HloInstruction* AddFusionOperand(HloInstruction* new_operand); - - // Merges the fused instructions from 'instruction_to_merge' into the - // fused instruction set of 'this', updating operands as necessary. - // - // Precondition: opcode() == HloOpcode::kFusion - // Predondition: 'instruction_to_merge' must be an operand of 'this'. - void MergeFusionInstruction(HloInstruction* instruction_to_merge); - - // Merges the fused instructions from instruction_to_merge into the fused - // instruction set of 'this' and generates multioutput fusion instructions. - // All the users of instruction_to_merge will be redirected to 'this' - // instruction. instruction_to_merge will be removed from its parent - // computation. 
- // - // Precondition: opcode() == HloOpcode::kFusion - void MergeFusionInstructionIntoMultiOutput( - HloInstruction* instruction_to_merge); - - // Fuses the given instruction in this fusion instruction. instruction_to_fuse - // is cloned and the clone is placed in the fusion - // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather - // than moved to cleanly handle the case where the instruction has a use - // outside the fusion instruction. Moving such an instruction into a fusion - // instruction would violate the single-result invariant of HLO instructions - // and significantly complicate code generation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { - return FuseInstructionInternal(instruction_to_fuse); + // TODO(b/80249101): Remove these methods once HLO scheduling and copy + // insertion are integrated, and we don't need to run a separate pass + // of copy elision anymore. + bool CopyElisionAllowed() const { + CHECK_EQ(HloOpcode::kCopy, opcode_); + return copy_elision_allowed_; } - // Fuses the given instruction in this fusion instruction and generate - // multioutput fusion instruction. A clone of the instruction_to_fuse will - // be part of the output of fusion instructions. The users of - // instruction_to_fuse will be redirected to this fusion instructions. - // instruction_to_fuse will be removed from its parent computation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstructionIntoMultiOutput( - HloInstruction* instruction_to_fuse) { - return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); - } - - // Returns the start index in the given dimension for a slice node. 
- // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_starts(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_starts_[dimension]; - } - const std::vector& slice_starts() const { return slice_starts_; } - - // Returns the (exclusive) limit index in the given dimension for a slice - // node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_limits(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_[dimension]; - } - const std::vector& slice_limits() const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_; - } - - // Returns the stride in the given dimension for a slice node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_strides(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_strides_[dimension]; - } - const std::vector& slice_strides() const { return slice_strides_; } - - // Returns the flag that describes whether a slice must be lowered into an - // offset into the original operand. - bool IsInPlaceSlice() const { return is_in_place_slice_; } - - // Sets and returns the flag that describes whether a slice must be lowered - // into an offset into the original operand. - bool SetIsInPlaceSlice(bool value) { - is_in_place_slice_ = value; - return value; + void SetCopyElisionAllowed(bool value) { + CHECK_EQ(HloOpcode::kCopy, opcode_); + copy_elision_allowed_ = value; } // Returns the size of the slice in the given dimension for a dynamic @@ -1265,22 +1065,6 @@ class HloInstruction { return dynamic_slice_sizes_; } - // Returns the number of exponent bits for a reduce-precision node. - // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 exponent_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return exponent_bits_; - } - - // Returns the number of mantissa bits for a reduce-precision node. 
- // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 mantissa_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return mantissa_bits_; - } - // Returns data on the window in a windowed operation such as // convolution. const Window& window() const { @@ -1318,16 +1102,6 @@ class HloInstruction { MakeUnique(dnums); } - FftType fft_type() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_type_; - } - - const std::vector& fft_length() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_length_; - } - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1350,11 +1124,6 @@ class HloInstruction { // Returns the dump string of the gather dimension numbers. string GatherDimensionNumbersToString() const; - // Returns the random distribution for this rng node. - // - // Precondition: opcode() == HloOpcode::kRng - RandomDistribution random_distribution() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1437,9 +1206,14 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Gets/sets the string identifier for this instruction. + // Gets the string identifier for this instruction. const string& name() const { return name_; } - void set_name(tensorflow::StringPiece name) { name_ = std::string(name); } + + // Sets the string identifier for this instruction. Name will be sanitized to + // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + void SetAndSanitizeName(const string& name) { + name_ = NameUniquer::GetSanitizedName(name); + } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. 
@@ -1520,13 +1294,160 @@ class HloInstruction { void set_outer_dimension_partitions( const std::vector& outer_dimension_partitions); - // Change the layout for an Constant Hlo instruction to match new_layout. For - // tuple shaped constants shape_index is the path to the internal array - // subshape whose layout needs to be changed. + // Old methods kept for smooth subclassing transition BEGIN. + // TODO(b/80131774): Remove this code. + + // Delegates to HloBatchNormInstruction::feature_index. + int64 feature_index() const; + + // Delegates to HloBatchNormInstruction::epsilon. + float epsilon() const; + + // Delegates to HloFftInstruction::fft_type. + FftType fft_type() const; + + // Delegates to HloFftInstruction::fft_length. + const std::vector& fft_length() const; + + // Delegates to HloSendRecvInstruction::channel_id. + int64 channel_id() const; + + // Returns the dimension sizes or numbers associated with this instruction. + virtual const std::vector& dimensions() const { + LOG(FATAL) << "Unimplemented method."; + } + virtual int64 dimensions(int64 index) const { + LOG(FATAL) << "Unimplemented method."; + } + + // Delegates to HloConcatenateInstruction::concatenate_dimension. + int64 concatenate_dimension() const; + + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + + // Delegates to HloSliceInstruction::slice_starts. + int64 slice_starts(int64 dimension) const; + const std::vector& slice_starts() const; + + // Delegates to HloSliceInstruction::slice_limits. + int64 slice_limits(int64 dimension) const; + const std::vector& slice_limits() const; + + // Delegates to HloSliceInstruction::slice_strides. + int64 slice_strides(int64 dimension) const; + const std::vector& slice_strides() const; + + // Delegates to HloSliceInstruction::IsInPlaceSlice. + bool IsInPlaceSlice() const; + + // Returns the literal associated with this instruction.
+ const Literal& literal() const; + + // Returns whether the instruction is a constant. + bool IsConstant() const; + + // Delegates to HloConstantInstruction::RelayoutConstant. void RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index = {}); + // Delegates to HloTraceInstruction::TracingTag. + string TracingTag() const; + + // Delegates to HloFusionInstruction::AddFusionOperand. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + + // Delegates to HloFusionInstruction::MergeFusionInstruction. + void MergeFusionInstruction(HloInstruction* instruction_to_merge); + + // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput. + void MergeFusionInstructionIntoMultiOutput( + HloInstruction* instruction_to_merge); + + // Delegates to HloFusionInstruction::FuseInstruction. + HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse); + + // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput. + HloInstruction* FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse); + + // Delegates to HloFusionInstruction::fused_instructions_computation. + HloComputation* fused_instructions_computation() const; + + // Delegates to HloFusionInstruction::fused_expression_root. + HloInstruction* fused_expression_root() const; + + // Delegates to HloFusionInstruction::fused_instructions. + const tensorflow::gtl::iterator_range>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + fused_instructions(); + + // Delegates to HloFusionInstruction::fused_instruction_count. + int64 fused_instruction_count() const; + + // Delegates to HloFusionInstruction::fused_parameter. + HloInstruction* fused_parameter(int64 parameter_number) const; + + // Delegates to HloFusionInstruction::fused_parameters. + const std::vector& fused_parameters() const; + + // Returns true if this instruction is a fusion instruction that generates + // multiple outputs.
+ const bool IsMultiOutputFusion() const; + + // Delegates to HloFusionInstruction::fusion_kind. + FusionKind fusion_kind() const; + + // Delegates to HloFusionInstruction::set_fusion_kind. + void set_fusion_kind(FusionKind kind); + + // Delegates to HloRngInstruction::random_distribution. + RandomDistribution random_distribution() const; + + // Delegates to HloParameterInstruction::parameter_number. + int64 parameter_number() const; + + // Delegates to HloGetTupleElementInstruction::tuple_index. + int64 tuple_index() const; + + // Delegates to HloReducePrecisionInstruction::exponent_bits. + int32 exponent_bits() const; + + // Delegates to HloReducePrecisionInstruction::mantissa_bits. + int32 mantissa_bits() const; + + // Delegates to HloInfeedInstruction::infeed_config. + string infeed_config() const; + + // Delegates to HloInfeedInstruction::set_infeed_config. + void set_infeed_config(const string& config); + + // Returns the config for the Outfeed instruction. + const string& outfeed_config() const; + + // Returns the shape for the Outfeed instruction. + const Shape& outfeed_shape() const; + + // Delegates to HloAllReduceInstruction::replica_group_ids. + const std::vector& replica_group_ids() const; + + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. + string cross_replica_sum_barrier() const; + void set_cross_replica_sum_barrier(const string& barrier); + + // Delegates to HloAllReduceInstruction::all_reduce_id. + tensorflow::gtl::optional all_reduce_id() const; + // Old methods kept for smooth subclassing transition END. + protected: + enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; + // Helper class for computing OperandElementUse for kFusion. + class FusionReusesParamElements; + // Internal constructor for a given opcode/shape, other fields must be filled // by factory methods. HloInstruction(HloOpcode opcode, const Shape& shape); @@ -1535,6 +1456,16 @@ class HloInstruction { // of the operand. 
void AppendOperand(HloInstruction* operand); + void RemoveOperandAt(int index) { + operands_.erase(operands_.begin() + index); + } + + void AppendComputation(HloComputation* computation) { + called_computations_.push_back(computation); + } + + void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); } + private: // Implementation for non-common logic of CloneWithNewOperands. virtual std::unique_ptr CloneWithNewOperandsImpl( @@ -1544,6 +1475,20 @@ class HloInstruction { // TODO(b/80131774): This should be pure virtual. LOG(FATAL) << "Unimplemented method."; } + + // Implementation for non-common logic of ExtraAttributesToString. + virtual std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {}; + } + + // Implementation for IsElementwise if operand_idx is nullopt and for + // IsElementwiseOnOperand if otherwise. + // + // NOTE: For all instructions other than kFusion, being elementwise on one of + // the operands is equivalent to being elementwise on all the operands. + virtual bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const; // Prints an instruction to a string. // // The canonical string representation needs to name operands and instruction @@ -1554,7 +1499,7 @@ class HloInstruction { CanonicalNameMap* canonical_name_map) const; // Prints an operand to a string. - string OperandsToStringWithCanonicalNameMap( + virtual string OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const; @@ -1562,11 +1507,6 @@ class HloInstruction { // OperandsToStringWithCanonicalNameMap() functions. friend class HloComputation; - enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; - - // Helper class for computing OperandElementUse for kFusion. - class FusionReusesParamElements; - // See comments on Identical(). 
virtual bool IdenticalSlowPath( const HloInstruction& other, @@ -1584,38 +1524,6 @@ class HloInstruction { // Removes a user for this instruction. void RemoveUser(HloInstruction* user); - // Fuses the given instruction into this fusion instruction. When add_output - // is false (which is the default), instruction_to_fuse is cloned and the - // clone is placed in the fusion instruction. instruction_to_fuse is - // unchanged. - // - // When add_output is true, a clone of the instruction_to_fuse will be part - // of the output of fusion instructions. The users of instruction_to_fuse - // will be redirected to this fusion instructions. instruction_to_fuse will - // be removed from its parent computation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, - bool add_output = false); - - // Clones the given instruction_to_fuse and insert the clone into this fusion - // instruction. If add_output is true, a clone of instruction_to_fuse will - // be in the output of the this fusion instruction (part of the tuple of the - // fusion root). - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, - bool add_output = false); - - // Clones a fusion instruction with a new shape and operands. - std::unique_ptr CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloCloneContext* context = nullptr) const; - - // Returns true if this instruction can legally have the dimensions field - // set. Used for checking precondition of dimensions field accessors. - bool CanHaveDimensionsField() const; - // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; @@ -1647,22 +1555,9 @@ class HloInstruction { // The computation in which this instruction is contained. HloComputation* parent_ = nullptr; - // Shape of outfeed request. 
- Shape outfeed_shape_; - // Result shape of this instruction. Shape shape_; - // Literal, only present for kConstant. - std::unique_ptr literal_; - - // Constant index, only present for kGetTupleElement. - int64 tuple_index_ = -1; - - // Dimensions present for some operations that require reshaping or - // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. - std::vector dimensions_; - // Describes the window in a windowed operation such as convolution. std::unique_ptr window_; @@ -1675,23 +1570,8 @@ class HloInstruction { std::unique_ptr gather_dimension_numbers_; std::vector gather_window_bounds_; - // Describes FFT type for an FFT instruction. - FftType fft_type_ = FftType::FFT; - - // Indicates the FFT length for an FFT instruction. - std::vector fft_length_; - - // Describes the [begin, end) index range for a slice. - std::vector slice_starts_; - std::vector slice_limits_; - std::vector slice_strides_; - - // Describes whether the slice can be lowered to an offset into the operand. - bool is_in_place_slice_ = false; - - // The bit sizes for a reduce-precision operation. - int32 exponent_bits_ = 0; - int32 mantissa_bits_ = 0; + // Used to tag kCopy instructions that are eligible for copy elision. + bool copy_elision_allowed_ = true; // Describes the [start, start + size) range size for a dynamic slice // ('start' is specified dynamically in the second operand of the operation). @@ -1701,9 +1581,6 @@ class HloInstruction { // padding of this pad instruction. Only set for pad instructions. std::unique_ptr padding_config_; - // The type of the fusion. Used by kFusion only. - FusionKind fusion_kind_; - // The sharding, if one exists. std::unique_ptr sharding_; @@ -1711,9 +1588,6 @@ class HloInstruction { std::unique_ptr operand_side_metadata_; std::unique_ptr user_side_metadata_; - // For parameter instructions this field holds the parameter number. - int64 parameter_number_ = 0; - // Name of a global symbol to call, only present for kCustomCall. 
string custom_call_target_; @@ -1742,26 +1616,12 @@ class HloInstruction { kFalseComputationIndex = 1, }; - // Outfeed configuration information, only present for kOutfeed. - string outfeed_config_; - // A trace instruction that consumes this instruction. // // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as // an operand. HloInstruction* trace_instruction_ = nullptr; - // The distribution requested for random number generation. - // Only present for kRng. - RandomDistribution distribution_; - - // Represents a unique identifier for each Send/Recv instruction pair. - // Only present for kSend or kRecv. - int64 channel_id_ = -1; - - // The string representation of the infeed configuration. - string infeed_config_; - // The backend-specific configuration for how a backend should compile this // HLO. See the documentation on backend_config(). string backend_config_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 313033ddadce6a49936f8d34d38f33e923dc2e35..5d6f8b931f0c665fba03e1c845214fa83aabf12e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -342,7 +342,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Builds a parameter and feeds it to the map. HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10, {param0}, add_f32)); module->AddEntryComputation(builder.Build()); @@ -381,7 +381,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { // Builds a parameter and an initial value and feeds them to the reduce. 
HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( @@ -980,6 +980,23 @@ TEST_F(HloInstructionTest, FullyElementwise) { } } +TEST_F(HloInstructionTest, MapIsElementwise) { + auto module = CreateNewModule(); + const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0}); + HloComputation::Builder builder(TestName()); + HloComputation::Builder map_builder("id"); + map_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + auto map_computation = module->AddEmbeddedComputation(map_builder.Build()); + auto x = + builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x")); + auto map = builder.AddInstruction( + HloInstruction::CreateMap(r2f32, {x}, map_computation)); + module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(map->IsElementwise()); +} + TEST_F(HloInstructionTest, PartiallyElementwise) { const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5}); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index adbebb135bafb443aa27302df2b88f8a43b5ee6c..5871a6605fed24865d8cbe7e1cee5a4d5fadb357 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -15,10 +15,31 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + namespace xla { +namespace { +using ::tensorflow::str_util::CEscape; +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; +bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, + const HloInstruction* operand) { + std::vector operand_indices = instruction->OperandIndices(operand); + return std::all_of( + operand_indices.begin(), operand_indices.end(), + [instruction](int64 operand_index) { + return instruction->IsElementwiseOnOperand(operand_index); + }); +} +} // namespace + HloBatchNormInstruction::HloBatchNormInstruction( HloOpcode opcode, const Shape& shape, HloInstruction* operand, HloInstruction* scale, float epsilon, int64 feature_index) @@ -38,13 +59,6 @@ bool HloBatchNormInstruction::IdenticalSlowPath( epsilon() == casted_other.epsilon(); } -std::vector HloBatchNormInstruction::ExtraAttributesToString( - const HloPrintOptions& options) const { - std::vector extra = {StrCat("epsilon=", epsilon()), - StrCat("feature_index=", feature_index())}; - return extra; -} - HloInstructionProto HloBatchNormInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_epsilon(epsilon_); @@ -52,6 +66,12 @@ HloInstructionProto HloBatchNormInstruction::ToProto() const { return proto; } +std::vector HloBatchNormInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("epsilon=", epsilon()), + StrCat("feature_index=", feature_index())}; +} + HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, float epsilon, int64 feature_index) @@ -115,4 +135,1298 @@ 
HloBatchNormGradInstruction::CloneWithNewOperandsImpl( new_operands[4], epsilon(), feature_index()); } +HloFftInstruction::HloFftInstruction( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length) + : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) { + fft_length_.assign(fft_length.begin(), fft_length.end()); + AppendOperand(operand); +} + +HloInstructionProto HloFftInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_fft_type(fft_type_); + for (int64 fft_len : fft_length_) { + proto.add_fft_length(fft_len); + } + return proto; +} + +std::vector HloFftInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("fft_type=", FftType_Name(fft_type())), + StrCat("fft_length={", Join(fft_length(), ","), "}")}; +} + +bool HloFftInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return fft_type() == casted_other.fft_type() && + fft_length() == casted_other.fft_length(); +} + +std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], fft_type_, + fft_length_); +} + +HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, + const Shape& shape, + int64 channel_id) + : HloInstruction(opcode, shape), channel_id_(channel_id) {} + +HloInstructionProto HloSendRecvInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_channel_id(channel_id_); + return proto; +} + +std::vector HloSendRecvInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("channel_id=", channel_id_)}; +} + +bool HloSendRecvInstruction::IdenticalSlowPath( + const 
HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +// Send instruction produces a tuple of {aliased operand, U32 context}. +HloSendInstruction::HloSendInstruction(HloInstruction* operand, + int64 channel_id) + : HloSendRecvInstruction( + HloOpcode::kSend, + ShapeUtil::MakeTupleShape( + {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}), + channel_id) { + AppendOperand(operand); +} + +std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(new_operands[0], channel_id()); +} + +HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) + : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil(), + CHECK_NOTNULL(operand)->channel_id()) { + AppendOperand(operand); +} + +std::unique_ptr +HloSendDoneInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + Cast(new_operands[0])); +} + +// Recv instruction produces a tuple of {receive buffer, U32 context}. 
+// Builds a Recv on `channel_id`. The instruction takes no operands; its shape
+// is the tuple {shape, scalar U32}.
+HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id)
+    : HloSendRecvInstruction(
+          HloOpcode::kRecv,
+          ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}),
+          channel_id) {}
+
+// Clones this Recv. `shape` is expected to be the full tuple shape; element 0
+// is extracted because the constructor re-wraps it into a tuple.
+std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape,
+    tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+    HloCloneContext* context) const {
+  CHECK_EQ(new_operands.size(), 0);
+  return MakeUnique<HloRecvInstruction>(
+      ShapeUtil::GetTupleElementShape(shape, 0), channel_id());
+}
+
+// Builds the RecvDone paired with `operand`. The result shape is element 0 of
+// the Recv's tuple shape and the channel id is inherited from the Recv.
+//
+// Both uses of `operand` in the base-class initializer are null-checked: the
+// order in which the initializer's arguments are evaluated is unspecified, so
+// checking only one of them would not guarantee the check runs before the
+// other dereference.
+HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand)
+    : HloSendRecvInstruction(
+          HloOpcode::kRecvDone,
+          ShapeUtil::GetTupleElementShape(CHECK_NOTNULL(operand)->shape(), 0),
+          CHECK_NOTNULL(operand)->channel_id()) {
+  AppendOperand(operand);
+}
+
+// Clones this RecvDone onto a new Recv operand. `shape` is not used; the
+// constructor derives the shape from the operand.
+std::unique_ptr<HloInstruction>
+HloRecvDoneInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape,
+    tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+    HloCloneContext* context) const {
+  CHECK_EQ(new_operands.size(), 1);
+  return MakeUnique<HloRecvDoneInstruction>(
+      Cast<HloRecvInstruction>(new_operands[0]));
+}
+
+// Builds a cross-replica-sum instruction over `operands`, reduced with
+// `reduce_computation`. Copies the replica group ids and the barrier string.
+// CHECK-fails if `all_reduce_id` holds a value (see TODO below).
+HloAllReduceInstruction::HloAllReduceInstruction(
+    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+    HloComputation* reduce_computation,
+    tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+    tensorflow::StringPiece barrier,
+    const tensorflow::gtl::optional<int64>& all_reduce_id)
+    : HloInstruction(HloOpcode::kCrossReplicaSum, shape),
+      replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()),
+      cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
+      all_reduce_id_(all_reduce_id) {
+  // TODO(b/79737069): Remove the CHECK when supported.
+  CHECK(!all_reduce_id_.has_value());
+  for (auto operand : operands) {
+    AppendOperand(operand);
+  }
+  AppendComputation(reduce_computation);
+}
+
+// Serializes the base fields plus replica_group_ids. The barrier and
+// all_reduce_id are not written to the proto yet.
+HloInstructionProto HloAllReduceInstruction::ToProto() const {
+  HloInstructionProto proto = HloInstruction::ToProto();
+  for (int64 i : replica_group_ids_) {
+    proto.add_replica_group_ids(i);
+  }
+  // TODO(b/79737069): handle barrier and all_reduce_id.
+  return proto;
+}
+
+// Returns the printable attributes: replica_group_ids always, barrier only
+// when non-empty, all_reduce_id only when set.
+std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& /*options*/) const {
+  std::vector<string> result = {
+      StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")};
+  if (!cross_replica_sum_barrier().empty()) {
+    result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
+  }
+  if (all_reduce_id_.has_value()) {
+    result.push_back(StrCat("all_reduce_id=", *all_reduce_id_));
+  }
+  return result;
+}
+
+// Two all-reduces are identical when replica groups, reduction computation,
+// barrier and all_reduce_id all match. `other` is cast without a type check.
+bool HloAllReduceInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+        eq_computations) const {
+  const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
+  return replica_group_ids() == casted_other.replica_group_ids() &&
+         eq_computations(to_apply(), casted_other.to_apply()) &&
+         cross_replica_sum_barrier() ==
+             casted_other.cross_replica_sum_barrier() &&
+         all_reduce_id() == casted_other.all_reduce_id();
+}
+
+// Clones this all-reduce with `shape` and `new_operands`, carrying over the
+// reduction computation, replica groups, barrier and all_reduce_id.
+std::unique_ptr<HloInstruction>
+HloAllReduceInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape,
+    tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+    HloCloneContext* /*context*/) const {
+  return MakeUnique<HloAllReduceInstruction>(
+      shape, new_operands, to_apply(), replica_group_ids(),
+      cross_replica_sum_barrier(), all_reduce_id());
+}
+
+// Builds a Reverse of `operand`; `dimensions` is copied into dimensions_.
+HloReverseInstruction::HloReverseInstruction(
+    const Shape& shape, HloInstruction* operand,
+    tensorflow::gtl::ArraySlice<int64> dimensions)
+    : HloInstruction(HloOpcode::kReverse, shape),
+      dimensions_(dimensions.begin(), dimensions.end()) {
+  AppendOperand(operand);
+}
+
+// Serializes the base fields plus the stored dimensions.
+HloInstructionProto HloReverseInstruction::ToProto() const {
+  HloInstructionProto proto = HloInstruction::ToProto();
+  for (int64 dimension : dimensions_) {
+    proto.add_dimensions(dimension);
+  }
+  return proto;
+}
+
+// Returns the single printable attribute "dimensions={...}".
+std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& options) const {
+  return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+// Two Reverse instructions are identical when their dimensions match.
+// `eq_computations` is unused; `other` is cast without a type check.
+bool HloReverseInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+        eq_computations) const {
+  const auto& casted_other = static_cast<const HloReverseInstruction&>(other);
+  return dimensions() == casted_other.dimensions();
+}
+
+// Clones this Reverse onto the single new operand, keeping the dimensions.
+std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape,
+    tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+    HloCloneContext* context) const {
+  CHECK_EQ(new_operands.size(), 1);
+  return MakeUnique<HloReverseInstruction>(shape, new_operands[0],
+                                           dimensions());
+}
+
+// Builds a Concatenate of `operands`; `dimension` is stored as the only
+// entry of dimensions_.
+HloConcatenateInstruction::HloConcatenateInstruction(
+    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+    int64 dimension)
+    : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
+  for (auto operand : operands) {
+    AppendOperand(operand);
+  }
+}
+
+// Serializes the base fields plus the stored dimensions.
+HloInstructionProto HloConcatenateInstruction::ToProto() const {
+  HloInstructionProto proto = HloInstruction::ToProto();
+  for (int64 dimension : dimensions_) {
+    proto.add_dimensions(dimension);
+  }
+  return proto;
+}
+
+// Returns the single printable attribute "dimensions={...}".
+std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& options) const {
+  return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+// Two Concatenate instructions are identical when their dimensions match.
+// `eq_computations` is unused; `other` is cast without a type check.
+bool HloConcatenateInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+        eq_computations) const {
+  const auto& casted_other =
+      static_cast<const HloConcatenateInstruction&>(other);
+  return dimensions() == casted_other.dimensions();
+}
+
+// Clones this Concatenate over `new_operands`, reusing dimensions(0) as the
+// concatenation dimension. No operand-count check is made here.
+std::unique_ptr<HloInstruction>
+HloConcatenateInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape,
+    tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+    HloCloneContext* context) const {
+  return MakeUnique<HloConcatenateInstruction>(shape, new_operands,
+                                               dimensions(0));
+}
+
+// Builds a Reduce with operands {arg, init_value} (in that order) and
+// `reduce_computation` as the called computation. `dimensions_to_reduce` is
+// copied into dimensions_.
+HloReduceInstruction::HloReduceInstruction(
+    const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+    HloComputation* reduce_computation)
+    : HloInstruction(HloOpcode::kReduce, shape),
+      dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
+  AppendOperand(arg);
+  AppendOperand(init_value);
+  AppendComputation(reduce_computation);
+}
+
+HloInstructionProto HloReduceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloReduceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + // Reduction results are determined by the reduction dimension and the + // reduction computation. + return dimensions() == casted_other.dimensions() && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], dimensions(), to_apply()); +} + +HloTransposeInstruction::HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) + : HloInstruction(HloOpcode::kTranspose, shape), + dimensions_(dimensions.begin(), dimensions.end()) { + CHECK_EQ(shape.dimensions().size(), dimensions.size()); + CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); + CHECK(std::equal(operand->shape().dimensions().begin(), + operand->shape().dimensions().end(), + Permute(dimensions, shape.dimensions()).begin())) + << "shape: " << ShapeUtil::HumanString(shape) + << ", operand->shape(): " << ShapeUtil::HumanString(shape) + << ", dimensions: {" << Join(dimensions, ", ") << "}"; + AppendOperand(operand); +} + +bool HloTransposeInstruction::IsRank2Transpose() const { + return dimensions() == std::vector({1, 0}) && + shape().dimensions_size() == 2 && + std::equal(shape().dimensions().begin(), 
shape().dimensions().end(), + operand(0)->shape().dimensions().rbegin()); +} + +HloInstructionProto HloTransposeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloTransposeInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloTransposeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloTransposeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloBroadcastInstruction::HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension) + : HloInstruction(HloOpcode::kBroadcast, shape), + dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) { + AppendOperand(operand); +} + +HloInstructionProto HloBroadcastInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloBroadcastInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloBroadcastInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloBroadcastInstruction::CloneWithNewOperandsImpl( + 
const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloMapInstruction::HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands) + : HloInstruction(HloOpcode::kMap, shape) { + CHECK(static_operands.empty()) << "static_operands not yet supported"; + for (auto operand : operands) { + AppendOperand(operand); + } + AppendComputation(map_computation); + // TODO(b/65689298) Remove code below once Map is generalized to accept + // arbitrary map dimensions. + dimensions_.resize(ShapeUtil::Rank(shape)); + std::iota(dimensions_.begin(), dimensions_.end(), 0); +} + +HloInstructionProto HloMapInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +bool HloMapInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + if (!dimensions().empty()) { + // Check that the map is executed in elementwise compatible dimensions. 
+ if (dimensions().size() != shape().dimensions_size()) { + return false; + } + for (int i = 0; i < dimensions().size(); ++i) { + if (dimensions()[i] != i) { + return false; + } + } + } + return true; +} + +std::vector HloMapInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloMapInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return eq_computations(to_apply(), other.to_apply()); +} + +std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, new_operands, to_apply()); +} + +HloSliceInstruction::HloSliceInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) + : HloInstruction(HloOpcode::kSlice, shape), + slice_starts_(start_indices.begin(), start_indices.end()), + slice_limits_(limit_indices.begin(), limit_indices.end()), + slice_strides_(strides.begin(), strides.end()) { + AppendOperand(operand); + // For backward compatibility with old serialized computations: if there are + // no strides, assume all strides are 1. + // TODO(b/63317920): remove this code. 
+ if (slice_strides_.empty()) { + slice_strides_ = std::vector(start_indices.size(), 1LL); + } +} + +HloInstructionProto HloSliceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int i = 0; i < slice_starts_.size(); ++i) { + auto* slice_dimension = proto.add_slice_dimensions(); + slice_dimension->set_start(slice_starts_[i]); + slice_dimension->set_limit(slice_limits_[i]); + slice_dimension->set_stride(slice_strides_[i]); + } + return proto; +} + +std::vector HloSliceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector bounds; + bounds.reserve(slice_starts_.size()); + const bool omit_stride = + std::all_of(slice_strides_.begin(), slice_strides_.end(), + [](int64 stride) { return stride == 1; }); + for (int i = 0; i < slice_starts_.size(); ++i) { + string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); + bounds.push_back( + StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]")); + } + return {StrCat("slice={", Join(bounds, ", "), "}")}; +} + +bool HloSliceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return slice_starts_ == other_slice.slice_starts_ && + slice_limits_ == other_slice.slice_limits_ && + slice_strides_ == other_slice.slice_strides_; +} + +std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], slice_starts_, + slice_limits_, slice_strides_); +} + +HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) + : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), + literal_(std::move(literal)) {} + +HloConstantInstruction::HloConstantInstruction(const Shape& shape) + : HloInstruction(HloOpcode::kConstant, 
shape) {} + +HloInstructionProto HloConstantInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + if (literal_ != nullptr) { + *proto.mutable_literal() = literal_->ToProto(); + } + return proto; +} + +bool HloConstantInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + return true; +} + +void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index) { + Shape* mutable_array_subshape = + ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); + CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + + // Normally array_subshape will always have a layout, but this invariant is + // temporarily broken in LayoutAssignment::AssignLayouts. + + if (!mutable_array_subshape->has_layout() || + !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { + literal_ = literal_->Relayout(new_layout, shape_index); + *mutable_array_subshape->mutable_layout() = new_layout; + } +} + +bool HloConstantInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return literal() == other_slice.literal(); +} + +std::unique_ptr +HloConstantInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(literal_->CloneToUnique()); +} + +string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + string operands; + // For constants, show the actual value in place of an empty operand list. + if (literal_ != nullptr && + ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || + options.print_large_constants())) { + // Literal::ToString emits multidimensional arrays over multiple + // lines. Compact this into one line by stripping out white space. 
+ string tmp = literal().ToString(); + std::replace(tmp.begin(), tmp.end(), '\n', ' '); + std::vector v = tensorflow::str_util::Split(tmp, ' '); + bool first = true; + // Concatenate elements in "v" with spaces separating them, but ignoring + // empty entries. + for (const auto& s : v) { + if (s.empty()) { + continue; + } + StrAppend(&operands, (first ? "" : " "), s); + first = false; + } + } else { + // Do not show large constants or tuples. + operands = "{...}"; + } + return operands; +} + +HloTraceInstruction::HloTraceInstruction(const string& tag, + HloInstruction* operand) + : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()), + literal_(Literal::CreateR1U8(tag)) { + AppendOperand(operand); + operand->set_tracing(this); +} + +HloInstructionProto HloTraceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_literal() = literal_->ToProto(); + return proto; +} + +bool HloTraceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return false; +} + +std::unique_ptr HloTraceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode()); +} + +HloFusionInstruction::HloFusionInstruction(const Shape& shape, + FusionKind fusion_kind, + HloInstruction* fused_root) + : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { + CHECK(fused_root != nullptr); + SetAndSanitizeName("fusion"); + set_parent(fused_root->parent()); + set_metadata(fused_root->metadata()); + CloneAndFuseInternal(fused_root); +} + +HloFusionInstruction::HloFusionInstruction( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation) + : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { + for (auto operand : operands) { + 
AppendOperand(operand); + } + SetAndSanitizeName("fusion"); + AppendComputation(fusion_computation); + fusion_computation->SetFusionInstruction(this); +} + +HloInstructionProto HloFusionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_fusion_kind(xla::ToString(fusion_kind())); + proto.add_called_computation_ids( + fused_instructions_computation()->unique_id()); + return proto; +} + +bool HloFusionInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + if (fusion_kind() != FusionKind::kLoop) { + return false; + } + + if (!operand_idx.has_value()) { + for (auto* fused : fused_instructions()) { + if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { + return false; + } + } + return true; + } + // A loop-fusion is elementwise on an operand if all operations (computed + // using BFS) between the operand and the fused root are elementwise. + std::deque worklist; + std::unordered_set visited; + worklist.push_back(fused_parameter(operand_idx.value())); + visited.insert(fused_parameter(operand_idx.value())); + while (!worklist.empty()) { + HloInstruction* operand = worklist.front(); + worklist.pop_front(); + for (HloInstruction* user : operand->users()) { + CHECK_GE(user->unique_id(), 0); + if (ContainsKey(visited, user)) { + continue; + } + if (user->IsElementwise() || + IsInstructionElementwiseOnOperand(user, operand)) { + worklist.push_back(user); + visited.insert(user); + } else { + return false; + } + } + } + return true; +} + +HloInstruction* HloFusionInstruction::AddFusionOperand( + HloInstruction* new_operand) { + CHECK_EQ(operand_count(), + fused_instructions_computation()->parameter_instructions().size()); + const int64 param_no = operand_count(); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. 
+ string param_name = StrCat(new_operand->name(), ".param_", param_no); + HloInstruction* fused_parameter = + fused_instructions_computation()->AddParameter( + HloInstruction::CreateParameter(param_no, new_operand->shape(), + param_name)); + AppendOperand(new_operand); + return fused_parameter; +} + +void HloFusionInstruction::MergeFusionInstruction( + HloFusionInstruction* instruction_to_merge) { + CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != + operands().end()); + // Clone the instruction from which to merge fused instructions. + std::unique_ptr cloned = instruction_to_merge->Clone(); + HloFusionInstruction* cloned_fusion = + static_cast(cloned.get()); + // Replace uses of fused parameters with the corresponding operand of the + // fusion. Add all non-parameter fused instructions to + // 'unfused_instructions' to be merged into 'this'. This is done in reverse + // post order. + std::vector unfused_instructions; + auto fused_instructions = cloned_fusion->fused_instructions_computation() + ->MakeInstructionPostOrder(); + for (auto fused_it = fused_instructions.rbegin(); + fused_it != fused_instructions.rend(); ++fused_it) { + auto fused_instruction = *fused_it; + if (fused_instruction->opcode() == HloOpcode::kParameter) { + TF_CHECK_OK( + fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand( + fused_instruction->parameter_number()))); + } else { + unfused_instructions.push_back(fused_instruction); + } + } + CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root()); + // Replace instruction_to_merge use of 'this' with unfused_root. + TF_CHECK_OK( + instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front())); + // Fuse 'unfused_instructions' into 'this'. 
+ for (auto& instruction : unfused_instructions) { + FuseInstruction(instruction); + } + CHECK_EQ(0, cloned_fusion->user_count()); + TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( + cloned_fusion->fused_instructions_computation())); +} + +void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( + HloFusionInstruction* instruction_to_merge) { + // Add all non-parameter fused instructions to 'unfused_instructions' to be + // merged into 'this'. `old_to_new' maps the instructions in the fused node + // to the disaseembled fusion instructions. + // Note that we add the unfused instructions to this->parent_ computation. + // This is necessary because the unique_id needs for an instruction and + // it's only added when inserting to the computation. + tensorflow::gtl::FlatMap old_to_new; + std::vector unfused_instructions; + auto computation_to_merge = + instruction_to_merge->fused_instructions_computation(); + auto post_order = computation_to_merge->MakeInstructionPostOrder(); + for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) { + auto fused_instruction = *rit; + if (fused_instruction->opcode() == HloOpcode::kParameter) { + InsertOrDie(&old_to_new, fused_instruction, + instruction_to_merge->mutable_operand( + fused_instruction->parameter_number())); + continue; + } + + // Here we clone the insertion and call FuseInstructionIntoMultiOutput() + // which clones again. This can be improved. 
+ auto cloned_instruction = + parent()->AddInstruction(fused_instruction->Clone()); + unfused_instructions.push_back(cloned_instruction); + InsertOrDie(&old_to_new, fused_instruction, cloned_instruction); + } + for (auto unfused_instruction : unfused_instructions) { + for (int64 index = 0; index < unfused_instruction->operand_count(); + index++) { + auto new_operand = + FindOrDie(old_to_new, unfused_instruction->mutable_operand(index)); + TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand)); + } + } + + HloInstruction* unfused_root = unfused_instructions.front(); + TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root)); + + TF_CHECK_OK( + instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge)); + if (GetModule()) { + TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge)); + } + + // Fuse the root instruction and generate multiple outputs. + FuseInstructionIntoMultiOutput(unfused_root); + TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); + // The rest instructions are of normal fusing. 
+ for (int64 i = 1; i < unfused_instructions.size(); i++) { + auto instruction = unfused_instructions[i]; + FuseInstruction(instruction); + TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); + } +} + +HloComputation* HloFusionInstruction::fused_instructions_computation() const { + CHECK(!called_computations().empty()); + auto* fused_instructions_computation = called_computations().front(); + CHECK(fused_instructions_computation->IsFusionComputation()) + << "Computation " << fused_instructions_computation->name() + << " is not a fusion kind"; + return fused_instructions_computation; +} + +HloInstruction* HloFusionInstruction::fused_expression_root() const { + return fused_instructions_computation()->root_instruction(); +} + +HloInstruction* HloFusionInstruction::fused_parameter( + int64 parameter_number) const { + return fused_instructions_computation()->parameter_instruction( + parameter_number); +} + +const std::vector& HloFusionInstruction::fused_parameters() + const { + return fused_instructions_computation()->parameter_instructions(); +} + +const tensorflow::gtl::iterator_range>::const_iterator>> +HloFusionInstruction::fused_instructions() const { + const HloComputation* subcomp = fused_instructions_computation(); + return subcomp->instructions(); +} + +const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> +HloFusionInstruction::fused_instructions() { + return fused_instructions_computation()->instructions(); +} + +int64 HloFusionInstruction::fused_instruction_count() const { + return fused_instructions_computation()->instruction_count(); +} + +HloInstruction* HloFusionInstruction::FuseInstructionInternal( + HloInstruction* instruction_to_fuse, bool add_output) { + // When add_output is false, this fusion instruction must be a user of + // instruction_to_fuse. 
+ if (!add_output) { + CHECK(IsUserOf(instruction_to_fuse)); + } + HloInstruction* fused_instruction = + CloneAndFuseInternal(instruction_to_fuse, add_output); + return fused_instruction; +} + +HloInstruction* HloFusionInstruction::CloneAndFuseInternal( + HloInstruction* instruction_to_fuse, bool add_output) { + CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); + VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); + HloInstruction* clone = nullptr; + if (called_computations().empty()) { + // New fusion instruction. It should not be a multioutput instruction. + CHECK(!add_output); + auto builder = HloComputation::Builder("fused_computation", this); + builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); + AppendComputation( + CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); + clone = fused_expression_root(); + } else { + clone = fused_instructions_computation()->AddInstruction( + instruction_to_fuse->Clone(/*suffix=*/"")); + // When add_output is false, instruction_to_fuse is necessarily an operand + // of the fusion instruction. After fusion this will no longer be the + // case. Remove the operand from the operand list and remove its + // corresponding fused parameter instruction. Renumber parameters as + // necessary to make parameter numbers consistent with their index in the + // fused_parameter_ vector. + bool in_operand_list = std::find(operands().begin(), operands().end(), + instruction_to_fuse) != operands().end(); + CHECK(add_output || in_operand_list); + const std::vector& fused_parameters = + fused_instructions_computation()->parameter_instructions(); + for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { + if (instruction_to_fuse == operand(operand_num)) { + // replace the fused parameter instruction's uses with the clone. 
+ HloInstruction* fused_parameter = fused_parameters[operand_num]; + TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone)); + + // Remove the corresponding fused parameter and operand from their + // respective vectors. + TF_CHECK_OK( + fused_instructions_computation()->RemoveParameter(operand_num)); + RemoveOperandAt(operand_num); + break; + } + } + // We've cloned instruction_to_fuse into this fusion instruction, so this + // fusion instruction is no longer a use of instruction_to_fuse. + if (in_operand_list) { + DetachFrom(instruction_to_fuse); + // When the instruction_to_fuse does not have other users, we don't need + // to generate a multioutput fusion instruction. + if (instruction_to_fuse->user_count() == 0) { + add_output = false; + } + } + } + + // Reread the parameters in the computation. + const std::vector& fused_parameters = + fused_instructions_computation()->parameter_instructions(); + + // Add each operand of the clone as an operand of the fusion instruction. A + // complication is that some clone operands may already be operands of the + // fusion instruction. + for (int64 operand_num = 0; operand_num < clone->operand_count(); + ++operand_num) { + HloInstruction* operand = clone->mutable_operand(operand_num); + + // See if this operand is already an operand of the fusion node. + CHECK_EQ(operands().size(), fused_parameters.size()); + HloInstruction* fused_param = nullptr; + for (int64 i = 0; i < operands().size(); ++i) { + if (this->operand(i) == operand) { + fused_param = fused_parameters[i]; + break; + } + } + + if (fused_param == nullptr) { + // Clone's operand was not already an operand of the fusion + // instruction. Add it as an operand and add a corresponding fused + // parameter instruction. 
+ fused_param = AddFusionOperand(operand); + } + TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); + } + + if (add_output) { + CHECK_GT(instruction_to_fuse->user_count(), 0); + // If this is already a multioutput fusion instruction, expand the root + // tuple by 1. + HloInstruction* fused_root = fused_expression_root(); + HloInstruction::InstructionVector tuple_elements; + bool newly_created_tuple_instr = false; + if (fused_root->opcode() == HloOpcode::kTuple) { + tuple_elements = fused_root->operands(); + } else { + tuple_elements.push_back(fused_root); + newly_created_tuple_instr = true; + } + if (clone->opcode() == HloOpcode::kTuple) { + for (auto inst : clone->operands()) { + tuple_elements.push_back(inst); + } + } else { + tuple_elements.push_back(clone); + } + HloInstruction* new_root = fused_instructions_computation()->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); + fused_instructions_computation()->set_root_instruction(new_root); + *mutable_shape() = new_root->shape(); + if (fused_root->opcode() == HloOpcode::kTuple) { + TF_CHECK_OK( + fused_instructions_computation()->RemoveInstruction(fused_root)); + } + + // If this is a newly created multioutput instruction, we need to update + // the use of the original fusion instruction. 
+ if (newly_created_tuple_instr) { + HloInstruction* new_instr = parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); + TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); + } + int64 index = tuple_elements.size(); + if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { + index -= instruction_to_fuse->operand_count(); + std::vector to_be_removed; + for (auto old_gte : instruction_to_fuse->users()) { + CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); + int64 old_tuple_index = old_gte->tuple_index(); + HloInstruction* new_gte = + parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + old_gte->shape(), this, index + old_tuple_index)); + TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte)); + to_be_removed.push_back(old_gte); + } + for (auto old_gte : to_be_removed) { + TF_CHECK_OK(parent()->RemoveInstruction(old_gte)); + } + TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); + } else { + HloInstruction* new_gte = + parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + clone->shape(), this, index - 1)); + TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte)); + } + } + + VLOG(2) << "New clone:\n" << clone->ToString(); + return clone; +} + +std::vector HloFusionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("kind=", xla::ToString(fusion_kind()))}; +} + +bool HloFusionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return fusion_kind() == other.fusion_kind() && + eq_computations(fused_instructions_computation(), + other.fused_instructions_computation()); +} + +std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + HloModule* module = context != nullptr ? 
context->module() : GetModule(); + HloComputation* new_fused_computation = nullptr; + if (context != nullptr) { + new_fused_computation = + context->FindComputation(fused_instructions_computation()); + } + if (new_fused_computation == nullptr) { + new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", context)); + } + return MakeUnique(shape, fusion_kind(), new_operands, + new_fused_computation); +} + +HloRngInstruction::HloRngInstruction( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters) + : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) { + for (HloInstruction* param : parameters) { + AppendOperand(param); + } +} + +HloInstructionProto HloRngInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_distribution(distribution_); + return proto; +} + +std::vector HloRngInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("distribution=", RandomDistributionToString(distribution_))}; +} + +bool HloRngInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + return true; +} + +bool HloRngInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return false; +} + +std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, distribution_, new_operands); +} + +HloParameterInstruction::HloParameterInstruction(int64 parameter_number, + const Shape& shape, + const string& name) + : HloInstruction(HloOpcode::kParameter, shape), + parameter_number_(parameter_number) { + SetAndSanitizeName(name); +} + +HloInstructionProto HloParameterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + 
proto.set_parameter_number(parameter_number_); + return proto; +} + +string HloParameterInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + return StrCat(parameter_number_); +} + +bool HloParameterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return parameter_number() == casted_other.parameter_number(); +} + +std::unique_ptr +HloParameterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(parameter_number_, shape, name()); +} + +HloGetTupleElementInstruction::HloGetTupleElementInstruction( + const Shape& shape, HloInstruction* operand, int64 index) + : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) { + CHECK(ShapeUtil::IsTuple(operand->shape())); + AppendOperand(operand); +} + +HloInstructionProto HloGetTupleElementInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_tuple_index(tuple_index_); + return proto; +} + +std::vector HloGetTupleElementInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("index=", tuple_index())}; +} + +bool HloGetTupleElementInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return tuple_index() == casted_other.tuple_index(); +} + +std::unique_ptr +HloGetTupleElementInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + tuple_index()); +} + +HloReducePrecisionInstruction::HloReducePrecisionInstruction( + const Shape& shape, HloInstruction* 
operand, const int exponent_bits, + const int mantissa_bits) + : HloInstruction(HloOpcode::kReducePrecision, shape), + exponent_bits_(exponent_bits), + mantissa_bits_(mantissa_bits) { + AppendOperand(operand); +} + +HloInstructionProto HloReducePrecisionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_exponent_bits(exponent_bits_); + proto.set_mantissa_bits(mantissa_bits_); + return proto; +} + +std::vector HloReducePrecisionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("exponent_bits=", exponent_bits_), + StrCat("mantissa_bits=", mantissa_bits_)}; +} + +bool HloReducePrecisionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + // A reduce-precision operation is determined by the bit sizes. + return exponent_bits() == casted_other.exponent_bits() && + mantissa_bits() == casted_other.mantissa_bits(); +} + +std::unique_ptr +HloReducePrecisionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + shape, new_operands[0], exponent_bits(), mantissa_bits()); +} + +HloInfeedInstruction::HloInfeedInstruction(const Shape& shape, + const string& config) + : HloInstruction(HloOpcode::kInfeed, shape), infeed_config_(config) {} + +HloInstructionProto HloInfeedInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_infeed_config(infeed_config_); + return proto; +} + +std::vector HloInfeedInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (infeed_config_.empty()) { + return {}; + } + return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")}; +} + +bool HloInfeedInstruction::IdenticalSlowPath( + const HloInstruction& other, + const 
std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 0); + return MakeUnique(shape, infeed_config()); +} + +HloOutfeedInstruction::HloOutfeedInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::StringPiece outfeed_config) + : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()), + outfeed_shape_(shape), + outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + CHECK(ShapeUtil::Compatible(operand->shape(), shape)) + << "Outfeed shape " << shape << " must be compatible with operand shape " + << operand->shape(); + AppendOperand(operand); +} + +HloInstructionProto HloOutfeedInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_outfeed_config(outfeed_config()); + *proto.mutable_outfeed_shape() = outfeed_shape(); + return proto; +} + +std::vector HloOutfeedInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (outfeed_config_.empty()) { + return {}; + } + return {StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")}; +} + +bool HloOutfeedInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. 
+ return false; +} + +std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(outfeed_shape(), new_operands[0], + outfeed_config()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 6fcd96a8c66bcb74b22a8fd5152ed3e3680ce576..04df2d860ebe2cd1b7f94a78598295d87b29986f 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -32,19 +32,18 @@ class HloBatchNormInstruction : public HloInstruction { // number added to the variance to avoid divide-by-zero error. float epsilon() const { return epsilon_; } - // Returns string representation of op-specific attributes. - std::vector ExtraAttributesToString( - const HloPrintOptions& options) const override; - // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; protected: - HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, - HloInstruction* operand, HloInstruction* scale, - float epsilon, int64 feature_index); + explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, float epsilon, + int64 feature_index); private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; bool IdenticalSlowPath( const HloInstruction& other, const std::function& @@ -58,9 +57,11 @@ class HloBatchNormInstruction : public HloInstruction { class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { public: - HloBatchNormTrainingInstruction(const Shape& shape, HloInstruction* operand, - HloInstruction* scale, HloInstruction* offset, - float epsilon, int64 feature_index); + explicit HloBatchNormTrainingInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, + HloInstruction* offset, + float epsilon, int64 feature_index); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -72,11 +73,10 @@ class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { public: - HloBatchNormInferenceInstruction(const Shape& shape, HloInstruction* operand, - HloInstruction* scale, - HloInstruction* offset, HloInstruction* mean, - HloInstruction* variance, float epsilon, - int64 feature_index); + explicit HloBatchNormInferenceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index); private: // Implementation for non-common logic of CloneWithNewOperands. 
@@ -88,20 +88,760 @@ class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { class HloBatchNormGradInstruction : public HloBatchNormInstruction { public: - HloBatchNormGradInstruction(const Shape& shape, HloInstruction* operand, - HloInstruction* scale, HloInstruction* mean, - HloInstruction* variance, - HloInstruction* grad_output, float epsilon, - int64 feature_index); + explicit HloBatchNormGradInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, + HloInstruction* grad_output, float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloFftInstruction : public HloInstruction { + public: + explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, + FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + FftType fft_type() const { return fft_type_; } + + const std::vector& fft_length() const { return fft_length_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes FFT type for an FFT instruction. + FftType fft_type_ = FftType::FFT; + + // Indicates the FFT length for an FFT instruction. 
+ std::vector fft_length_; +}; + +class HloSendRecvInstruction : public HloInstruction { + public: + // Returns the channel id associated with the instruction. The id is + // shared between each Send/Recv pair and is globally unique to identify each + // channel. + int64 channel_id() const { return channel_id_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + protected: + explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, + int64 channel_id); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Represents a unique identifier for each Send/Recv instruction pair. + int64 channel_id_; +}; + +class HloSendInstruction : public HloSendRecvInstruction { + public: + explicit HloSendInstruction(HloInstruction* operand, int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloSendDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloSendDoneInstruction(HloSendInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvInstruction(const Shape& shape, int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvDoneInstruction(HloRecvInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloAllReduceInstruction : public HloInstruction { + public: + explicit HloAllReduceInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id = + tensorflow::gtl::nullopt); + + // Returns the group ids of each replica for CrossReplicaSum op. + const std::vector& replica_group_ids() const { + return replica_group_ids_; + } + + // Returns the barrier config used for the CrossReplicaSum implementation of + // each backend. + string cross_replica_sum_barrier() const { + return cross_replica_sum_barrier_; + } + void set_cross_replica_sum_barrier(string barrier) { + cross_replica_sum_barrier_ = barrier; + } + + tensorflow::gtl::optional all_reduce_id() const { + return all_reduce_id_; + } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The group id of each replica for CrossReplicaSum. + std::vector replica_group_ids_; + + // The string representation of the barrier config used for CrossReplicaSum. + string cross_replica_sum_barrier_; + + // For Allreduce nodes from different modules, if they have the same + // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross modules. + tensorflow::gtl::optional all_reduce_id_; +}; + +class HloReverseInstruction : public HloInstruction { + public: + explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloConcatenateInstruction : public HloInstruction { + public: + explicit HloConcatenateInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension); + // Returns the dimension sizes or numbers associated with this instruction. 
+ const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Accessor for the dimension in which a concatenate HLO should occur. + int64 concatenate_dimension() const { return dimensions(0); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloReduceInstruction : public HloInstruction { + public: + explicit HloReduceInstruction( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloTransposeInstruction : public HloInstruction { + public: + explicit HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloBroadcastInstruction : public HloInstruction { + public: + explicit HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. 
+ HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloMapInstruction : public HloInstruction { + public: + explicit HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands = {}); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloSliceInstruction : public HloInstruction { + public: + explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + HloInstructionProto ToProto() const override; + + // Returns the start index in the given dimension for a slice node. + int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; } + const std::vector& slice_starts() const { return slice_starts_; } + + // Returns the (exclusive) limit index in the given dimension for a slice + // node. + int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; } + const std::vector& slice_limits() const { return slice_limits_; } + + // Returns the stride in the given dimension for a slice node. + int64 slice_strides(int64 dimension) const { + return slice_strides_[dimension]; + } + const std::vector& slice_strides() const { return slice_strides_; } + + // Returns the flag that describes whether a slice must be lowered into an + // offset into the original operand. + bool IsInPlaceSlice() const { return is_in_place_slice_; } + + // Sets and returns the flag that describes whether a slice must be lowered + // into an offset into the original operand. + bool SetIsInPlaceSlice(bool value) { + is_in_place_slice_ = value; + return value; + } + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes the [begin, end) index range for a slice. + std::vector slice_starts_; + std::vector slice_limits_; + std::vector slice_strides_; + + // Describes whether the slice can be lowered to an offset into the operand. + bool is_in_place_slice_ = false; +}; + +class HloConstantInstruction : public HloInstruction { + public: + explicit HloConstantInstruction(std::unique_ptr literal); + // Used when the literal is too large and dropped. + explicit HloConstantInstruction(const Shape& shape); + // Returns the literal associated with this instruction. + const Literal& literal() const { return *literal_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Change the layout for an Constant Hlo instruction to match new_layout. For + // tuple shaped constants shape_index is the path to the internal array + // subshape whose layout needs to be changed. + void RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index = {}); + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. 
+ std::unique_ptr literal_; +}; + +class HloTraceInstruction : public HloInstruction { + public: + explicit HloTraceInstruction(const string& tag, HloInstruction* operand); + // Returns a tag to be used in tracing. + string TracingTag() const { return literal_->GetR1U8AsString(); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. + std::unique_ptr literal_; +}; + +class HloFusionInstruction : public HloInstruction { + public: + explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, + HloInstruction* fused_root); + + explicit HloFusionInstruction( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation); + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Adds a new operand the fusion instruction. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + + // Merges the fused instructions from 'instruction_to_merge' into the + // fused instruction set of 'this', updating operands as necessary. + // + // Predondition: 'instruction_to_merge' must be an operand of 'this'. + void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); + + // Merges the fused instructions from instruction_to_merge into the fused + // instruction set of 'this' and generates multioutput fusion instructions. + // All the users of instruction_to_merge will be redirected to 'this' + // instruction. instruction_to_merge will be removed from its parent + // computation. 
+ void MergeFusionInstructionIntoMultiOutput( + HloFusionInstruction* instruction_to_merge); + + // Fuses the given instruction in this fusion instruction. instruction_to_fuse + // is cloned and the clone is placed in the fusion + // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather + // than moved to cleanly handle the case where the instruction has a use + // outside the fusion instruction. Moving such an instruction into a fusion + // instruction would violate the single-result invariant of HLO instructions + // and significantly complicate code generation. + HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { + return FuseInstructionInternal(instruction_to_fuse); + } + + // Fuses the given instruction in this fusion instruction and generate + // multioutput fusion instruction. A clone of the instruction_to_fuse will + // be part of the output of fusion instructions. The users of + // instruction_to_fuse will be redirected to this fusion instructions. + // instruction_to_fuse will be removed from its parent computation. + HloInstruction* FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse) { + return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); + } + + // Returns the computation for this fused instruction. + HloComputation* fused_instructions_computation() const; + + // Returns the root instruction of the fused expression contained within this + // fusion instruction. + HloInstruction* fused_expression_root() const; + + // Returns the list of fused instructions inside this fusion instruction. The + // returned type is a range of HloInstruction*s. + const tensorflow::gtl::iterator_range>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + fused_instructions(); + + // Gets the number of instructions inside this fusion instruction. 
+ int64 fused_instruction_count() const; + + // Returns the fused parameter instruction in this fusion instruction + // corresponding to the given parameter number. + HloInstruction* fused_parameter(int64 parameter_number) const; + + // Returns the vector of fused parameters inside this fusion instruction. + const std::vector& fused_parameters() const; + + // Returns true if this instruction is a fusion instruction that generates + // multiple outputs. + const bool IsMultiOutputFusion() const { + return fused_expression_root()->opcode() == HloOpcode::kTuple; + } + + FusionKind fusion_kind() const { return fusion_kind_; } + + void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } private: + // Fuses the given instruction into this fusion instruction. When add_output + // is false (which is the default), instruction_to_fuse is cloned and the + // clone is placed in the fusion instruction. instruction_to_fuse is + // unchanged. + // + // When add_output is true, a clone of the instruction_to_fuse will be part + // of the output of fusion instructions. The users of instruction_to_fuse + // will be redirected to this fusion instructions. instruction_to_fuse will + // be removed from its parent computation. + HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, + bool add_output = false); + // Clones the given instruction_to_fuse and insert the clone into this fusion + // instruction. If add_output is true, a clone of instruction_to_fuse will + // be in the output of the this fusion instruction (part of the tuple of the + // fusion root). 
+ HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, + bool add_output = false); + + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const override; + + // The type of the fusion. Used by kFusion only. + FusionKind fusion_kind_; }; +class HloRngInstruction : public HloInstruction { + public: + explicit HloRngInstruction( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters); + // Returns the random distribution for this rng node. + RandomDistribution random_distribution() const { return distribution_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The distribution requested for random number generation. 
+ RandomDistribution distribution_; +}; + +class HloParameterInstruction : public HloInstruction { + public: + explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, + const string& name); + int64 parameter_number() const { return parameter_number_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + int64 parameter_number_ = 0; +}; + +class HloGetTupleElementInstruction : public HloInstruction { + public: + explicit HloGetTupleElementInstruction(const Shape& shape, + HloInstruction* operand, int64 index); + // Returns the tuple index associated with this instruction. + int64 tuple_index() const { return tuple_index_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + int64 tuple_index_ = -1; +}; + +class HloReducePrecisionInstruction : public HloInstruction { + public: + explicit HloReducePrecisionInstruction(const Shape& shape, + HloInstruction* operand, + const int exponent_bits, + const int mantissa_bits); + // Returns the number of exponent bits for a reduce-precision node. + int32 exponent_bits() const { return exponent_bits_; } + // Returns the number of mantissa bits for a reduce-precision node. + int32 mantissa_bits() const { return mantissa_bits_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The bit sizes for a reduce-precision operation. + int32 exponent_bits_ = 0; + int32 mantissa_bits_ = 0; +}; + +class HloInfeedInstruction : public HloInstruction { + public: + explicit HloInfeedInstruction(const Shape& shape, const string& config); + // Returns the infeed configuration string. The infeed configuration includes + // any metadata needed for the backend compiler (e.g., infeed buffer address) + // and is target-dependent. + string infeed_config() const { return infeed_config_; } + void set_infeed_config(const string& config) { infeed_config_ = config; } + // Returns a serialized representation of this instruction. 
+ HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The string representation of the infeed configuration. + string infeed_config_; +}; + +class HloOutfeedInstruction : public HloInstruction { + public: + explicit HloOutfeedInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::StringPiece outfeed_config); + // Returns the shape for the Outfeed instruction. + const Shape& outfeed_shape() const { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); + return outfeed_shape_; + } + // Returns the config for the Outfeed instruction. + const string& outfeed_config() const { return outfeed_config_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Shape of outfeed request. + Shape outfeed_shape_; + // Outfeed configuration information, only present for kOutfeed. 
+ string outfeed_config_; +}; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index c570b420c21fed4d7828feb24ee5c7859db94a79..8a31a8e617c1fb82201e07d9a3ff1ab9a618206b 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -187,6 +187,7 @@ HLO_MATCHER(Exp); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); HLO_MATCHER(Ge); +HLO_MATCHER(GenerateToken); HLO_MATCHER(Gt); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index e63424c2dfb6c7b9e71e4cede896a8f6609fea62..9c59374b4a9d7e3dbfb99d8a6b30d4230e553658 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -32,15 +32,6 @@ limitations under the License. namespace xla { -HloModule::HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config) - : name_(NameUniquer::GetSanitizedName(name)), - config_(config), - has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle), - unique_id_(next_unique_module_id_++) {} - HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(NameUniquer::GetSanitizedName(name)), config_(config), @@ -234,8 +225,7 @@ HloModuleProto HloModule::ToProto() const { /* static */ StatusOr> HloModule::CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle) { + const HloModuleProto& proto, const HloModuleConfig& module_config) { // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. 
TF_RET_CHECK(proto.has_program_shape()) @@ -287,8 +277,7 @@ StatusOr> HloModule::CreateFromProto( } TF_RET_CHECK(entry != nullptr); - auto module = MakeUnique(proto.name(), entry_computation_handle, - module_config); + auto module = MakeUnique(proto.name(), module_config); // Sort the computations in the proto id's order. std::sort(computations.begin(), computations.end(), @@ -401,7 +390,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( // as a parameter in the new function. arguments.push_back(old_operand); *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter( - parameter_count, old_operand->shape(), "")); + parameter_count, old_operand->shape(), "p")); ++parameter_count; } TF_CHECK_OK( @@ -525,8 +514,6 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; auto module = MakeUnique(name_ + "-" + suffix, config_); - module->entry_computation_handle_ = entry_computation_handle_; - module->has_entry_computation_handle_ = has_entry_computation_handle_; HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index c93c74d34a95cfbb3d0d334fb1c1f40a5aad69e9..757e65bda286d983d05e5a791aa7dffe97bac945 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -31,7 +31,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -57,10 +56,6 @@ namespace xla { // attached to. class HloModule { public: - HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config); - // Constructor without a versioned computation handle. This constructor should // only be used for HloModules used outside of the XLA service (eg // tests). The versioned handle is used by the service in the compilation @@ -126,10 +121,6 @@ class HloModule { return config_.device_entry_computation_layout(); } - const VersionedComputationHandle& entry_computation_handle() const { - return entry_computation_handle_; - } - // Gets the computations in this module. // // Returns a view of HloComputation*s, so you can iterate over this in the @@ -188,9 +179,7 @@ class HloModule { // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; static StatusOr> CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle = - VersionedComputationHandle()); + const HloModuleProto& proto, const HloModuleConfig& module_config); // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. @@ -264,10 +253,6 @@ class HloModule { mutable std::mt19937_64 rng_{42}; mutable tensorflow::mutex rng_mutex_; - // Versioned handle of the entry computation of the module. 
- bool has_entry_computation_handle_ = false; - VersionedComputationHandle entry_computation_handle_; - // Unique name generator for computation and instruction names, which are // unique per module. NameUniquer computation_name_uniquer_{/*separator=*/"."}; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 4f1715e4cafd1a7a2d8626dc3ad386813e5c2d76..bf33640db16638803f4f8e6c66f35d6bb6e2c9fe 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -127,9 +127,14 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const { for (HloInstruction* instruction : *companions) { // Go through all the communicating instructions (send, recv) of the given // companion, and record their device. + auto it = tracked_instructions_comms_.find(instruction); + if (it == tracked_instructions_comms_.end()) { + // Companions can be added even if they have no communicating + // instructions, if they are parent of companions. 
+ continue; + } std::unordered_set comm_devices; - for (HloInstruction* comm_instruction : - tracked_instructions_comms_.at(instruction)) { + for (HloInstruction* comm_instruction : it->second) { auto device = GetInstructionDevice(*comm_instruction); TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString() << " does not have a device"; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 1fe06ee0c0d14255b8358fb998bfd8d0b029506f..a35546f5f41b149d119ee141fd734da8bfd055b2 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -81,6 +81,7 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kHostCompute, "host-compute") \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index cd2ce5c69f030c65b889d67e082a3677b8739ddb..774345124b4ad62e35d9423a23f1dbaa28e44d80 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kConcatenate: case HloOpcode::kFusion: case HloOpcode::kMap: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index dcd4725fe78e8b9b5d14437e964cb5aaf1664117..6c1e015f77a62c3e3ff7ffa5ce9dea735f46e10a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -232,6 +232,11 @@ bool HloOrdering::UseIsBeforeValueDefinition( << " and def is in FALSE computation"; 
return true; } + if (value.defining_instruction() == use.instruction) { + VLOG(4) << " use is conditional " << use << " and def is " + << value.ToShortString(); + return true; + } } VLOG(4) << " use is not before value"; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index ee526d8dd7f7e81b3a846741d3e452935f486bd2..985f3fa64d8767b0c0063ee900f7d11c3b7f6d4a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -183,6 +183,10 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: + // TODO(dimvar): HloModuleSequence is not a good name because it sounds like + // a sequence of modules, instead of a map of schedules for all computations + // in a module. We should change it at some point. + // // A sequence of instructions for each computation in the module. using HloModuleSequence = tensorflow::gtl::FlatMap to_apply; + optional> replica_group_ids; + optional barrier; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; + attrs["replica_group_ids"] = { + /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; + attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands, *to_apply)); + + if (replica_group_ids) { + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, *replica_group_ids, + barrier ? *barrier : "")); + } else { + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, {}, barrier ? 
*barrier : "")); + } break; } case HloOpcode::kReshape: { @@ -606,6 +620,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateReshape(shape, operands[0])); break; } + case HloOpcode::kGenerateToken: { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateGenerateToken(operands)); + break; + } case HloOpcode::kTuple: { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; @@ -777,6 +799,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional to_apply; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; + optional> dimensions; + attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, + &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -1137,7 +1162,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloOpcodeString(opcode))); } - instruction->set_name(name); + instruction->SetAndSanitizeName(name); + if (instruction->name() != name) { + return Error(name_loc, + StrCat("illegal instruction name: ", name, + "; suggest renaming to: ", instruction->name())); + } // Add shared attributes like metadata to the instruction, if they were seen. 
if (sharding) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 08068dc5042d58abe5ca97a4eac91afe2040015b..d551400d1ec62d659399e930529e4a4aa7bfaa7d 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -765,7 +765,7 @@ add_F32.v3 { ENTRY MapBinaryAdder.v3 { param0 = f32[4]{0} parameter(0) param1 = f32[4]{0} parameter(1) - ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3 + ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3 } )" @@ -913,11 +913,29 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add + ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add } )" }, +// cross-replica-sum with subgroups +{ +"CrossReplicaSumWithSubgroups", +R"(HloModule CRS_Subgroups + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CrossReplicaSumWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add +} + +)" +} }); // clang-format on } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index bd1d9935bd37ff71064a1f8f431b2ddf9c7c789d..62c07d7fac93618a83b3b6111aec1e93309a0761 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -1201,7 +1202,8 @@ StatusOr HloRematerialization::RematerializeComputation( StatusOr HloRematerialization::Run( HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit_bytes, RematerializationSizes* sizes) { + int64 memory_limit_bytes, RematerializationSizes* sizes, + bool run_copy_elision) { // The sequence is constructed entirely by this method. TF_RET_CHECK(sequence->empty()); @@ -1230,12 +1232,21 @@ StatusOr HloRematerialization::Run( XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( + TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( *module, [this](const BufferValue& buffer) { return size_function_(buffer.shape()); }, scheduler_algorithm_)); + if (run_copy_elision) { + // We run a separate pass of copy elision here because the sequential + // ordering from the HLO schedule allows for more copies to be eliminated. + // TODO(b/80249101): Instead of a separate copy elision pass, use the + // ordering from the HLO schedule directly for copy insertion. + SequentialHloOrdering ordering(module, *sequence); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module)); + } + // Compute peak memory usage of all computations in the module called in a // sequential context. 
call_graph_ = CallGraph::Build(module); @@ -1338,9 +1349,10 @@ StatusOr HloRematerialization::Run( int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes) { + RematerializationSizes* sizes, bool run_copy_elision) { HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes); + return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes, + run_copy_elision); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 2ee2dd0571ae8c6604e4ca722351fd48a913bda5..59b4cf5dcc761f70767ce4d7ff0959448f29939a 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -57,6 +57,12 @@ class HloRematerialization { // sizes: Optional outparam that indicates the peak memory usage of the HLO // module before/after rematerialization. // + // run_copy_elision: Enable copy elision. This pass is used to eliminate + // copies that were inserted before HLO scheduling. + // + // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy + // insertion is integrated with HLO scheduling. + // // Returns whether any instructions were rematerialized. If memory use is // already below the given limit then no instructions are rematerialized and // false is returned. 
@@ -68,7 +74,7 @@ class HloRematerialization { const ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes = nullptr); + RematerializationSizes* sizes, bool run_copy_elision = true); protected: HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, @@ -83,7 +89,8 @@ class HloRematerialization { // contains the memory-minimizing order in which to emit the HLO instructions. StatusOr Run(HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit, RematerializationSizes* sizes); + int64 memory_limit, RematerializationSizes* sizes, + bool run_copy_elision); // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 83de54f3fa56ee660b79d8c366dbc0b52f9fde87..7a46da6efe0df23129d56e16355cf66aceb68ffe 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -40,7 +41,8 @@ class HloRematerializationTest : public HloTestBase { // Creates and returns a computation which can benefit from // rematerialization. 
The computation looks like: // - // F32[] %param = {...} + // F32[1] %param = {...} + // F32[] %reshape = reshape(F32[], param) // F32[1024] %bcast = broadcast(%param) // F32[1024] %negate = negate(%bcast) // F32[2048] %concat_1 = concat({%negate, %negate}) @@ -57,9 +59,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast)); auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate( @@ -100,9 +104,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto slice_1 = builder.AddInstruction( HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0}, /*limit_indices=*/{1}, @@ -135,6 +141,15 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } + StatusOr RunHloRematerialization( + int64 memory_limit_bytes, HloModule* module, + 
SequentialHloOrdering::HloModuleSequence* sequence) { + TF_EXPECT_OK(verifier().Run(module).status()); + return HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, + sequence, /*sizes=*/nullptr, /*run_copy_elision=*/false); + } + // Various shapes used in the canned computations. const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); @@ -158,11 +173,9 @@ TEST_F(HloRematerializationTest, SingleComputation) { SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/14 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -188,18 +201,16 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/20 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, + module.get(), &sequence)); // No instructions should have been materialized. 
EXPECT_FALSE(changed); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -225,23 +236,21 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/17 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. 
- EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 8); + EXPECT_EQ(body_computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -264,20 +273,18 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/15 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // Both computations should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(body_computation->instruction_count(), 8); + // Both computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(body_computation->instruction_count(), 9); } // Test rematerialization of a doubly nested computation. 
All computations @@ -303,24 +310,22 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/middle_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(middle_computation->instruction_count(), 6); - EXPECT_EQ(inner_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(middle_computation->instruction_count(), 7); + EXPECT_EQ(inner_computation->instruction_count(), 8); // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/13 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // All computations should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(middle_computation->instruction_count(), 7); - EXPECT_EQ(inner_computation->instruction_count(), 8); + // All computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(middle_computation->instruction_count(), 9); + EXPECT_EQ(inner_computation->instruction_count(), 9); } TEST_F(HloRematerializationTest, RngNotRematerialized) { @@ -382,10 +387,9 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { // parameter and output) and 20KB (peak memory possible with // rematerialization). 
TF_ASSERT_OK_AND_ASSIGN( - bool changed, HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, + bool changed, RunHloRematerialization( /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), DefaultMemoryScheduler, &sequence)); + module.get(), &sequence)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,11 +480,9 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -573,11 +575,9 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. 
if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 68b2cde83a2eb479d9ba71fc6eab9ac9ab1c8267..641b9ecec9c55ab0d14c28a5c5e84b00c2322499 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -36,29 +36,6 @@ using ::tensorflow::strings::HumanReadableNumBytes; namespace xla { -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { - return 0; - } - - const HloModule* module = module_sequence.begin()->first->parent(); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // The absolute minimum memory required for a given sequence of instructions - // is determined by the sequence of Alloc and Free calls on a simulated heap, - // ignoring fragmentation. We run the heap simulation on the whole module, - // rather than summing each computation, since it gives us a better lower - // bound, by minimizing the liveness of sub-computations. 
- TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), *module, - module_sequence, *points_to_analysis, size_function)); - return result.heap_size; -} - namespace { // Class implementing a list scheduler of HLO instructions which produces a @@ -398,7 +375,7 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleComputationHelper( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -416,18 +393,6 @@ StatusOr> CreateMemoryMinimizingSequence( } // namespace -StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; -} - StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, @@ -533,29 +498,29 @@ StatusOr> DefaultMemoryScheduler( std::vector list_sequence, ListMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 list_memory, - MinimumMemoryForComputation(computation, list_sequence, - points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(const int64 list_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, list_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 dfs_memory, - MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, - 
size_function)); + TF_ASSIGN_OR_RETURN(const int64 dfs_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, dfs_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); TF_ASSIGN_OR_RETURN( std::vector post_order_sequence, PostOrderMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 post_order_memory, - MinimumMemoryForComputation(computation, post_order_sequence, - points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(const int64 post_order_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, post_order_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory post order sequence: " << HumanReadableNumBytes(post_order_memory); @@ -576,10 +541,9 @@ StatusOr> DefaultMemoryScheduler( } } -StatusOr -CreateMemoryMinimizingSequence(const HloModule& module, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { +StatusOr ScheduleComputationsInModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); @@ -587,12 +551,13 @@ CreateMemoryMinimizingSequence(const HloModule& module, for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(auto one_computation_sequence, - CreateMemoryMinimizingSequence( + ScheduleComputationHelper( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = - MinimumMemoryForComputation(*computation, one_computation_sequence, - *points_to_analysis, size_function) + HeapSimulator::MinimumMemoryForComputation( + 
*computation, one_computation_sequence, *points_to_analysis, + size_function, &memory_by_computation) .ValueOrDie(); sequence[computation] = std::move(one_computation_sequence); } @@ -600,15 +565,15 @@ CreateMemoryMinimizingSequence(const HloModule& module, return sequence; } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); tensorflow::gtl::FlatMap empty_map; - return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function, nullptr, empty_map); + return ScheduleComputationHelper(computation, *points_to_analysis, + size_function, nullptr, empty_map); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 49b927eefd24f4e26df781dd8d2b977bedba2b80..2b33ccc8bfb895286bb3747aab0a16cf25e2cfae 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -28,20 +28,6 @@ limitations under the License. namespace xla { -// Returns the minimum memory required to compute the given module sequence, -// assuming no fragmentation. -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function); - -// Returns the minimum memory required to compute the given computation, -// assuming no fragmentation. 
-StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); - // A memory scheduler computes an execution sequence for the HLO instructions in // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function @@ -89,14 +75,13 @@ StatusOr> DefaultMemoryScheduler( // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. -StatusOr -CreateMemoryMinimizingSequence(const HloModule& module, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}); +StatusOr ScheduleComputationsInModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); -// Overload of above that computes the sequence for a single computation. +// Computes the schedule for a single computation. // Currently only used by the GPU backend. -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index db7ef6f0d4bd96216ea07ccc75a51513822bf2e3..73f22f81f4e9cf597db8b184642acff2fdaaf2b0 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -31,65 +32,6 @@ limitations under the License. namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloTestBase {}; - -TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { - auto module = CreateNewModule(); - const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); - - auto cond_builder = HloComputation::Builder("WhileCond"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); - HloInstruction* cond_iter = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); - HloInstruction* cond_data = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); - // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) - HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); - HloComputation* cond_computation = - module->AddEmbeddedComputation(cond_builder.Build()); - - auto body_builder = HloComputation::Builder("WhileBody"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "body_param")); - HloComputation* body_computation = - module->AddEmbeddedComputation(body_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - // Entry params: 8 bytes (4 bytes per param), TOTAL=8 - HloInstruction* iter = builder.AddInstruction( - 
HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); - HloInstruction* data = builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "param_data")); - // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 - HloInstruction* tuple = - builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); - // While: 8 bytes (4 bytes per element), TOTAL=32 - // Both cond and body use a max of 24 bytes, TOTAL=56 - HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( - tuple_shape, cond_computation, body_computation, tuple)); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; - - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, - MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); -} - class HloSchedulingTest : public HloTestBase {}; TEST_F(HloSchedulingTest, LastUseScheduledFirst) { @@ -124,7 +66,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, [](const BufferValue& buffer) { + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. @@ -165,7 +107,7 @@ ENTRY root { }; TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, size_fn, ListMemoryScheduler)); + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.at(module->entry_computation()).size()); @@ -203,7 +145,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { // ROOT %subtract = f32[4]{0} subtract( // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) // } - // %SubcomputationsNotAccounted () -> f32[2,4] { + // %ListAccountsForSubcomputations () -> f32[2,4] { // %constant.3 = f32[2,4]{1,0} constant( // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) // %transpose = f32[2,4]{1,0} transpose( @@ -269,16 +211,16 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence( - *module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - }, - ListMemoryScheduler)); + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + auto entry_computation = module->entry_computation(); + EXPECT_EQ(entry_computation->instruction_count(), + sequence.at(entry_computation).size()); SequentialHloOrdering ordering(module.get(), sequence); // This schedule is an example of List's greedy heuristics being suboptimal. 
// The while_loop is more expensive than transpose, so it would have been @@ -287,6 +229,24 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); + + tensorflow::gtl::FlatMap memory_by_computation; + memory_by_computation[cond_computation] = 17; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator doesn't account for subcomputations + EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn) + .ValueOrDie()); + // HeapSimulator accounts for subcomputations. The max mem doesn't change + // because the while body isn't live during the peak. + EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); } TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { @@ -318,12 +278,12 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, - [&TUPLE_SIZE](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + ScheduleComputationsInModule(*module, + [&TUPLE_SIZE](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -368,7 +328,7 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { {tuple, mul, add}, HloInstruction::FusionKind::kLoop); TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence( + ScheduleComputationsInModule( *module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), 2); @@ -384,5 +344,70 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); } +TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { + auto module = CreateNewModule(); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // param != 0 + // Needs 17 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = 
builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + auto entry_computation = module->entry_computation(); + EXPECT_EQ(entry_computation->instruction_count(), + sequence.at(entry_computation).size()); + + tensorflow::gtl::FlatMap memory_by_computation; + memory_by_computation[cond_computation] = 17; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator doesn't account for subcomputations + EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn) + .ValueOrDie()); + // HeapSimulator accounts for subcomputations + EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 58224ef870096a774d5892b9aa12c38f5ff511bd..9fb15df7c26951fb7f0d62b0d6533d6312e7a4d5 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -39,6 +39,34 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { return HloSharding(tile_shape, assignment); } +HloSharding 
HloSharding::Tuple(const ShapeTree& sub_shardings) { + std::vector flattened_list; + flattened_list.reserve(sub_shardings.leaf_count()); + for (const auto& index_to_sharding : sub_shardings.leaves()) { + flattened_list.push_back(index_to_sharding.second); + } + if (flattened_list.empty()) { + // Empty tuple sharding ends up having no leaves, but we want to allow + // empty tuple HLO instruction results to have sharding, so we fetch the + // root ({}) sharding value from the ShapeTree. + // A ShapeTree created with ShapeTree(shape, init) will have + // init as value at its root. + flattened_list.push_back(sub_shardings.element(ShapeIndex({}))); + } + return HloSharding(flattened_list); +} + +HloSharding HloSharding::Tuple( + const Shape& tuple_shape, + tensorflow::gtl::ArraySlice shardings) { + CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + std::vector flattened_list(shardings.begin(), shardings.end()); + CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape)) + << "Flat list has " << flattened_list.size() << ", required " + << RequiredLeaves(tuple_shape); + return HloSharding(flattened_list); +} + string HloSharding::ToString() const { if (IsTuple()) { std::vector parts; @@ -123,24 +151,49 @@ std::vector HloSharding::TileLimitForDevice(int64 device) const { return index; } +int64 HloSharding::RequiredLeaves(const Shape& shape) { + // Empty tuples have no leaf nodes as far as ShapeUtil and ShapeTree are + // concerned, but they do have a single tuple_elements_ entry since we want + // to allow empty tuple results to have sharding. + return ShapeUtil::IsEmptyTuple(shape) ? 
1 : ShapeUtil::GetLeafCount(shape); +} + +Status HloSharding::CheckLeafCount(const Shape& shape) const { + int64 shape_leaves = RequiredLeaves(shape); + TF_RET_CHECK(shape_leaves == tuple_elements_.size()) + << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves + << " leaf nodes while this sharding has " << tuple_elements_.size(); + return Status::OK(); +} + StatusOr> HloSharding::AsShapeTree( const Shape& shape) const { if (IsTuple()) { ShapeTree result(shape, HloSharding::Replicate()); - int64 num_leaves = result.leaf_count(); - TF_RET_CHECK(num_leaves == tuple_elements_.size()) - << "Shape " << ShapeUtil::HumanString(shape) << " has " << num_leaves - << " leaf nodes while this sharding has " << tuple_elements_.size(); + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); auto it = tuple_elements_.begin(); for (auto& index_to_sharding : result.leaves()) { index_to_sharding.second = *it++; } + if (ShapeUtil::IsEmptyTuple(shape)) { + // Empty tuples have no leaves, but we want to assign them a sharding + // anyway, so we use the root element sharding. + *result.mutable_element(ShapeIndex({})) = *it; + } return std::move(result); } else { return ShapeTree(shape, *this); } } +StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { + if (IsTuple()) { + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); + return *this; + } + return Tuple(ShapeTree(shape, *this)); +} + StatusOr HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { @@ -182,28 +235,12 @@ Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { return tensorflow::errors::InvalidArgument( StrCat("Sharding is tuple-shaped but validation shape is not.")); } - // The easiest way to get the number of elements in a nested tuple is just to - // create a shape tree. We could call GetAsShapeTree, but that will try and - // apply our tuple_shardings_ to the shape tree, and that might cause a crash - // at this point as we haven't validated them. 
- ShapeTree bool_shape_tree(shape, false); - int64 num_leaves = - std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end()); - if (num_leaves != tuple_elements_.size()) { - return tensorflow::errors::InvalidArgument( - StrCat("Validation tuple shape has ", num_leaves, - " leaf elements, but this sharding contains ", - tuple_elements_.size(), " elements.")); - } + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); // Now we've validated the number of tuple elements, it's safe to request a // shape tree. ShapeTree shape_tree = GetAsShapeTree(shape); for (const auto& index_to_sharding : shape_tree.leaves()) { - if (index_to_sharding.first.empty()) { - // An empty tuple has a ShapeTree with a single leaf at the empty index. - continue; - } Status status = index_to_sharding.second.ValidateNonTuple( ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices); if (!status.ok()) { @@ -389,6 +426,19 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, : sub_shape_tree.element(ShapeIndex({})); } +tensorflow::gtl::optional HloSharding::ExtractSingleSharding() + const { + if (!IsTuple()) { + return *this; + } + for (int64 i = 1; i < tuple_elements_.size(); ++i) { + if (tuple_elements_[0] != tuple_elements_[i]) { + return tensorflow::gtl::optional(); + } + } + return tuple_elements_.front(); +} + std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { out << sharding.ToString(); return out; diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index f4a0fb626f2c3e417c020cbfa2f7168359a47788..6a744e0247273e25c5de3143b7bbba2b79ee816a 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -70,26 +70,13 @@ class HloSharding { // Creates a new sharding for a tuple type. The given ShapeTree must have // elements for every leaf shape contained in the tuple. 
- static HloSharding Tuple(const ShapeTree& sub_shardings) { - std::vector flattened_list; - flattened_list.reserve( - std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end())); - for (const auto& index_to_sharding : sub_shardings.leaves()) { - flattened_list.push_back(index_to_sharding.second); - } - return HloSharding(flattened_list); - } + static HloSharding Tuple(const ShapeTree& sub_shardings); - // Creates a new sharding for a tuple type. The requested tuple shape must not - // be nested. For nested tuples, use the ShapeTree overload. + // Creates a new sharding for a tuple type. The number of elements in + // shardings must match the number of leaf nodes in tuple_shape. For + // empty tuples, the shardings array must have one element. static HloSharding Tuple(const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings) { - CHECK(ShapeUtil::IsTuple(tuple_shape)); - CHECK(!ShapeUtil::IsNestedTuple(tuple_shape)); - std::vector flattened_list(shardings.begin(), shardings.end()); - CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape)); - return HloSharding(flattened_list); - } + tensorflow::gtl::ArraySlice shardings); // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); @@ -172,6 +159,18 @@ class HloSharding { // REQUIRES: IsTuple() HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const; + // If the current sharding is a tuple sharding, return itself as result. + // Otherwise returns a tuple sharding for the input shape, with all the leaves + // having this object sharding. + StatusOr GetTupleSharding(const Shape& shape) const; + + // Extracts the sharding that is common within the current sharding. + // If the current sharding is not a tuple sharding, the current sharding will + // be returned. If it is a tuple, and all the tuple elements are common, the + // common element will be returned. Otherwise the optional will contain no + // value. 
+ tensorflow::gtl::optional ExtractSingleSharding() const; + bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && @@ -260,11 +259,19 @@ class HloSharding { tile_assignment_({0}), tuple_elements_(tuple_shardings) {} + // Checks that the number of elements in tuple_elements_ is consistent with + // the tuple shape passes as argument. + Status CheckLeafCount(const Shape& shape) const; + // Internal helper to validate a tuple sharding. Status ValidateTuple(const Shape& shape, int64 num_devices) const; + // Internal helper to validate a non-tuple (leaf) sharding. Status ValidateNonTuple(const Shape& shape, int64 num_devices) const; + // Returns the number of tuple_elements_ entries to fit the shape. + static int64 RequiredLeaves(const Shape& shape); + bool replicated_; bool maximal_; bool tuple_; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 82cff2a4b7146c2d454feb2d90673d419ca1a54d..748273a43cecca7a9c7392bb84f0e4c7133cfb14 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -31,32 +31,22 @@ struct PassThrough { HloInstruction* operand = nullptr; }; -void SetDeviceSharding(HloInstruction* instruction, int64 device) { - VLOG(4) << " " << instruction->name() << " to device " << device; - instruction->set_device_sharding(device); -} - -tensorflow::gtl::optional ShardingUniqueDevice( - const HloSharding& sharding) { - if (sharding.IsTileMaximal()) { - auto device = sharding.UniqueDevice(); - if (device.ok()) { - return device.ValueOrDie(); - } - } - return tensorflow::gtl::optional(); +void SetSingleSharding(HloInstruction* instruction, + const HloSharding& sharding) { + VLOG(4) << " " << instruction->name() << " to " << sharding; + instruction->set_single_sharding(sharding); } 
bool ShardingMatches(const HloSharding& sharding1, const HloSharding& sharding2) { - auto device1 = ShardingUniqueDevice(sharding1); - if (device1) { - auto device2 = ShardingUniqueDevice(sharding2); - if (device2) { - return *device1 == *device2; + auto single_sharding1 = sharding1.ExtractSingleSharding(); + if (single_sharding1) { + auto single_sharding2 = sharding2.ExtractSingleSharding(); + if (single_sharding2) { + return *single_sharding1 == single_sharding2; } } - // Anything which is not tile maximal with unique device, gets a full sharding + // Anything which is not unique across all elements, gets a full sharding // compare. return sharding1 == sharding2; } @@ -119,21 +109,21 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, std::unique_ptr CloneShardingForDomain( const HloSharding& sharding) { - auto device = ShardingUniqueDevice(sharding); - if (!device) { + auto single_sharding = sharding.ExtractSingleSharding(); + if (!single_sharding) { return MakeUnique(sharding); } - return MakeUnique(HloSharding::AssignDevice(*device)); + return MakeUnique(*single_sharding); } -Status ApplyDomainDeviceSharding(const DomainMetadata::Domain& domain, - int64 device) { - VLOG(4) << "Applying device " << device << " sharding"; +Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + VLOG(4) << "Applying " << sharding << " sharding"; for (HloInstruction* instruction : domain.instructions) { // We only change instructions without sharding, since otherwise we might // mess up with eventual HLO passes which has knowledge of it. 
if (!instruction->has_sharding()) { - SetDeviceSharding(instruction, device); + SetSingleSharding(instruction, sharding); } else { VLOG(4) << " " << instruction->name() << " already has sharding " << instruction->sharding(); @@ -186,12 +176,15 @@ StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, const HloSharding* tuple_sharding = GetOperandSharding(tuple, domain, sharding); if (tuple_sharding != nullptr) { - TF_RET_CHECK(tuple_sharding->IsTuple()) << tuple->ToString(); - HloSharding sub_sharding = tuple_sharding->GetSubSharding( - tuple->shape(), {instruction->tuple_index()}); - VLOG(4) << " " << instruction->name() << " to sharding " - << sub_sharding; - instruction->set_sharding(sub_sharding); + if (tuple_sharding->IsTuple()) { + HloSharding sub_sharding = tuple_sharding->GetSubSharding( + tuple->shape(), {instruction->tuple_index()}); + VLOG(4) << " " << instruction->name() << " to sharding " + << sub_sharding; + instruction->set_sharding(sub_sharding); + } else { + SetSingleSharding(instruction, *tuple_sharding); + } ++assigned; } } else if (instruction->opcode() == HloOpcode::kTuple) { @@ -242,12 +235,29 @@ StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, Status ApplyDomainSharding(const DomainMetadata::Domain& domain, const HloSharding& sharding) { - auto device = ShardingUniqueDevice(sharding); - if (device) { - // Shortcut the simple case. We have a unique device sharding, so we call - // the ApplyDomainDeviceSharding() API which will apply array or tuple - // shaped device sharding to the domain instructions. - return ApplyDomainDeviceSharding(domain, *device); + // Here is the place to call external sharding normalizers, which are + // implemented in other modules (ie, spatial partitioning). 
+ // The signature of the external normalizer function should be something + // like: + // + // StatusOr Normalizer(const DomainMetadata::Domain&, + // const HloSharding& sharding); + // + // The function should return true if it has processed the domain + // normalization, false if domain was not one recognized by it, or an error. + // We will call the functions in order below, and fall back to local code if + // none of the external normalizers acted on the domain. + // External normalizers should not handle the cases that are already handled + // locally. + + // None of the external normalizers handled the domain sharding, try to see + // whether this is a single sharding first. + auto single_sharding = sharding.ExtractSingleSharding(); + if (single_sharding) { + // Shortcut the simple case. We have a unique sharding, so we call + // the ApplyDomainSingleSharding() API which will apply array or tuple + // shaped sharding to the domain instructions. + return ApplyDomainSingleSharding(domain, *single_sharding); } VLOG(1) << "Assigning non-trivial sharding " << sharding; for (;;) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index ee7133689b15348a18e6db9181199d5b25bf8143..54b7402b866361748d9eb35182b0bf486c4c9bdc 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -321,8 +321,10 @@ TEST_F(HloShardingTest, ParseHloString) { check(HloSharding::AssignDevice(2)); check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), Array4D({{{{0}, {1}}}}))); - // Empty tuple. - check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), {})); + // Empty tuple. One sharding is required for empty tuples, as we need to be + // able to assign sharding to them, even though they have no leaves. + check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), + {HloSharding::Replicate()})); { // Non-nested tuple. 
auto tuple_shape = diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 9cfd8a9bf74bc69ac40b1e0974d9e084d31071c9..1d6cd4cb2308fd09c7511e390a146a5224f253a3 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -426,6 +426,15 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } +Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) { + std::vector operand_shapes; + for (const HloInstruction* operand : token->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(token, + ShapeInference::InferGenerateTokenShape(operand_shapes)); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with @@ -791,6 +800,46 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { return Status::OK(); } +namespace { + +// Returns true if the given Shape has a TOKEN shape as any subshape. +bool ShapeContainsToken(const Shape& shape) { + bool contains_token = false; + ShapeUtil::ForEachSubshape( + shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsToken(subshape)) { + contains_token = true; + } + }); + return contains_token; +} + +// Verifies that all types entering and exiting the entry computation are +// legal. For example, TOKEN types have no Literal representation and cannot be +// on the interface of the entry computation (parameters and root instruction). 
+Status VerifyEntryAndExitShapes(const HloModule& module) { + for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { + HloInstruction* param = + module.entry_computation()->parameter_instruction(i); + if (ShapeContainsToken(param->shape())) { + return InternalError( + "Entry parameter %d is or contains a token shape: %s", i, + ShapeUtil::HumanString(param->shape()).c_str()); + } + } + if (ShapeContainsToken( + module.entry_computation()->root_instruction()->shape())) { + return InternalError( + "Entry root is or contains a token shape: %s", + ShapeUtil::HumanString( + module.entry_computation()->root_instruction()->shape()) + .c_str()); + } + return Status::OK(); +} + +} // namespace + StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); @@ -851,6 +900,8 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); } + TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 1392a78097aa026b2f7cffa2b0135402d3ca7ae5..7283b3e7dcdbed5be18a1da1571287cf0c089288 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -81,6 +81,7 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; + Status HandleGenerateToken(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 429c8503432b79f46aa0e5b1970bb565093128dd..abedb4063d3763516e66cff36633dbd90c8cafde 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ 
b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -96,6 +96,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: + case HloOpcode::kGenerateToken: case HloOpcode::kTranspose: case HloOpcode::kTuple: return false; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 7067b6f86a0fb24fb946ad236bca9bbd48d53722..eb469e77a08b976b91ed5e3cdea304a8148f73c5 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -937,6 +937,11 @@ LayoutAssignment::LayoutAssignment( ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), channel_layout_constraints_(channel_constraints) { + if (channel_layout_constraints_ != nullptr) { + // Save a copy of the input ChannelLayoutConstraints so that we can reset it + // if we have to undo previous operations (ClearPreviousPassSideEffects()). + channel_constraints_ = *channel_layout_constraints_; + } VLOG(1) << "Entry computation layout given to layout assignment: " << entry_computation_layout_->ToString(); // Layouts of all parameter instructions must be set. @@ -1614,13 +1619,57 @@ Status LayoutAssignment::RunOnComputation( // Record the layouts assigned for any communication ops in // channel_constraints so that they are constrained for future modules. + if (channel_constraints != nullptr) { + TF_RETURN_IF_ERROR( + ConstrainChannelLayouts(computation, channel_constraints)); + } + return Status::OK(); +} + +Status LayoutAssignment::ConstrainChannelLayouts( + HloComputation* computation, + ChannelLayoutConstraints* channel_constraints) { + // We go through the kRecvDone before. These must either impose their layout, + // or find a matching one already existing (ConstrainChannel() returns + // nullptr).
for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kRecvDone) { + const Layout* layout = channel_constraints->ConstrainChannel( + instruction->channel_id(), instruction->shape().layout()); + TF_RET_CHECK(layout == nullptr) + << instruction->ToString() + << " cannot constrain layout as it was set to " + << LayoutUtil::HumanString(*layout); + } + } + // After that we go through the kSend. These are likely going to have a kCopy + // as operand (otherwise we add it), so in case the constrained layout does + // not match, we can change the kCopy layout (and the kSend one as well). + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kSend) { - channel_constraints->ConstrainChannel( - instruction->channel_id(), instruction->operand(0)->shape().layout()); - } else if (instruction->opcode() == HloOpcode::kRecvDone) { - channel_constraints->ConstrainChannel(instruction->channel_id(), - instruction->shape().layout()); + HloInstruction* operand = instruction->mutable_operand(0); + const Layout* layout = channel_constraints->ConstrainChannel( + instruction->channel_id(), operand->shape().layout()); + if (layout != nullptr) { + // We found an already constrained layout which does not match the one + // the kSend wants to impose. Either add a new kCopy, or use the + // existing one to marshal the correct shape.
+ Shape shape = operand->shape(); + *shape.mutable_layout() = *layout; + if (operand->opcode() != HloOpcode::kCopy) { + HloInstruction* copy = operand->parent()->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand)); + RegisterAddedCopy(copy); + SetupCopiedInstruction(*operand, copy, {}); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy)); + operand = copy; + } else { + *operand->mutable_shape() = shape; + } + Shape* send_shape = + ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0}); + *send_shape = shape; + } } } return Status::OK(); @@ -1743,6 +1792,7 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); } + ResetChannelConstraints(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index c287cca0c54ba1bb514bd8d243c137eca99b258f..eb4cd5936b09145c7ba6351fdc9086d6d0f05bea 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -249,25 +249,30 @@ class ChannelLayoutConstraints { // Given `shape`, apply the layout for `channel_id`. `channel_id` must already // be constrained. Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const { - CHECK(IsChannelConstrained(channel_id)); - *shape.mutable_layout() = constraints_.at(channel_id); + auto it = constraints_.find(channel_id); + CHECK(it != constraints_.end()) << "Channel " << channel_id; + *shape.mutable_layout() = it->second; return shape; } // Returns the layout constraint for `channel_id`, which must already be // constrained. 
- Layout LayoutForChannel(int64 channel_id) const { - CHECK(IsChannelConstrained(channel_id)); - return constraints_.at(channel_id); + const Layout& LayoutForChannel(int64 channel_id) const { + auto it = constraints_.find(channel_id); + CHECK(it != constraints_.end()) << "Channel " << channel_id; + return it->second; } // Adds a new layout constraint for `channel_id`. If a constraint for - // `channel_id` already exists, this operation requires that the new layout is - // the same as the previously constrained layout. - void ConstrainChannel(int64 channel_id, const Layout& layout) { - CHECK(!IsChannelConstrained(channel_id) || - LayoutUtil::Equal(layout, constraints_[channel_id])); - constraints_[channel_id] = layout; + // `channel_id` has been added, this API returns nullptr, otherwise returns + // the layout which has already been set for the channel. + const Layout* ConstrainChannel(int64 channel_id, const Layout& layout) { + auto it = constraints_.emplace(std::make_pair(channel_id, layout)); + if (it.second) { + return nullptr; + } + return LayoutUtil::Equal(layout, it.first->second) ? nullptr + : &it.first->second; } private: @@ -464,6 +469,20 @@ class LayoutAssignment : public HloPassInterface { // itself). Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number); + // Apply the channel layout constraints by populating the channel_constraints + // data structure passed in at constructor time. Eventually adds copies in + // case two ends of a channel ended up with a different layout. + Status ConstrainChannelLayouts(HloComputation* computation, + ChannelLayoutConstraints* channel_constraints); + + // Resets the input ChannelLayoutConstraints to the original copy received + // from the constructor input. + void ResetChannelConstraints() { + if (channel_layout_constraints_ != nullptr) { + *channel_layout_constraints_ = channel_constraints_; + } + } + // Map containing the layouts of all computations assigned so // far.
Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller @@ -474,7 +493,14 @@ class LayoutAssignment : public HloPassInterface { // here. tensorflow::gtl::FlatSet added_copies_; - ChannelLayoutConstraints* channel_layout_constraints_; + // The pointer to the channel layout constraints passed in with the + // constructor. If not nullptr, this is an input/output argument. + ChannelLayoutConstraints* channel_layout_constraints_ = nullptr; + + // A copy of the input layout constraints used to reset the above pointer in + // case we have to undo operations due to the multiple passes over the + // computations/instructions. + ChannelLayoutConstraints channel_constraints_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index bf0448a67674f24591d866b646b98aea09ebb12c..62599b376a12808232c703479a0ccfd7a59aa9ad 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -52,10 +52,18 @@ using ::testing::ElementsAre; class LayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, - ComputationLayout* entry_computation_layout) { - LayoutAssignment layout_assignment(entry_computation_layout); + ComputationLayout* entry_computation_layout, + ChannelLayoutConstraints* channel_constraints = nullptr) { + LayoutAssignment layout_assignment( + entry_computation_layout, /*channel_constraints=*/channel_constraints); EXPECT_IS_OK(layout_assignment.Run(module).status()); } + + std::vector LayoutOf(HloModule* module, tensorflow::StringPiece name) { + auto minor_to_major = + FindInstruction(module, name)->shape().layout().minor_to_major(); + return std::vector(minor_to_major.begin(), minor_to_major.end()); + } }; TEST_F(LayoutAssignmentTest, ComputationLayout) { @@ -707,17 +715,10 @@ 
TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { LayoutUtil::MakeLayout({2, 1, 0})); AssignLayouts(module.get(), &computation_layout); - auto layout_of = [&](tensorflow::StringPiece name) { - return FindInstruction(module.get(), name) - ->shape() - .layout() - .minor_to_major(); - }; - - EXPECT_THAT(layout_of("gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(layout_of("gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(layout_of("gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(layout_of("fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0)); EXPECT_THAT(FindInstruction(module.get(), "gte1") ->shape() .tuple_shapes(0) @@ -816,5 +817,44 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { "Unexpected bitcast operation seen during layout assignment")); } +TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { + // Pin non matching layouts to parameter and root. 
+ const char* module_str = R"( + HloModule test_module + + ENTRY entry_computation { + param = (f32[2,2]) parameter(0) + gte = f32[2,2] get-tuple-element(param), index=0 + recv = (f32[2,2], u32[]) recv(), channel_id=1, sharding={maximal device=1} + ROOT recv-done = f32[2,2] recv-done(recv), channel_id=1, + sharding={maximal device=1} + send = (f32[2,2], u32[]) send(gte), channel_id=1, + sharding={maximal device=0} + send-done = () send-done(send), channel_id=1, sharding={maximal device=0} + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + Shape param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); + TF_ASSERT_OK( + computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape( + param_shape)); + computation_layout.mutable_result_layout()->ResetLayout( + LayoutUtil::MakeLayout({1, 0})); + + ChannelLayoutConstraints channel_constraints; + AssignLayouts(module.get(), &computation_layout, &channel_constraints); + + EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "recv-done"), ElementsAre(1, 0)); + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::GetSubshape( + FindInstruction(module.get(), "send")->shape(), {0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 21bca1d6beff5b2804531724b94b123d4523c173..f200a08a3cd7e33351ec4607d67d40e7ab28f3b9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -32,7 +32,8 @@ static const BufferAllocation* kParameterAllocation = new BufferAllocation( LogicalBuffer::Color(0)); void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, - 
llvm_ir::IrArray* array) { + llvm_ir::IrArray* array, + const ShapeIndex& index) { BufferAllocation::Slice buffer_slice; if (hlo.opcode() == HloOpcode::kParameter) { // Parameters may alias with each other but may not alias with our temporary @@ -40,7 +41,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0); } else { const std::set slices = - assignment_.GetAllSlices(&hlo, /*index=*/{}); + assignment_.GetAllSlices(&hlo, index); if (slices.empty() || slices.size() > 1) { // Skip HLOs which don't have a buffer assigned or for which the // buffer can't be determined statically. We cannot determine their @@ -137,16 +138,18 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( // 2. Operands of users of the given hlo. // 3. Operands of the given hlo. // - // This set can be increased as we need. For now only consider top-level - // buffers (index = {}) not buffers nested within the instruction's - // operands/output which are not typically touched. + // This set can be increased as we need. 
std::vector worklist; auto add_buffers_to_worklist = [&worklist, &assignment](const HloInstruction* instruction) { - for (const LogicalBuffer* buffer : - assignment.GetSourceBuffers(instruction, /*index=*/{})) { - worklist.push_back(buffer); - } + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&](const Shape& /*shape*/, const ShapeIndex& index) { + for (const LogicalBuffer* buffer : + assignment.GetSourceBuffers(instruction, index)) { + worklist.push_back(buffer); + } + }); }; for (HloInstruction* user : hlo.users()) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 5244ac61e56307857aca659854647bd6c3e991d7..fe9eab93aae95557e3ee27a64c09b78f37ac2348 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -38,7 +38,8 @@ class AliasAnalysis { // Augments IrArray with aliasing information. void AddAliasingInformationToIrArray(const HloInstruction& hlo, - llvm_ir::IrArray* array); + llvm_ir::IrArray* array, + const ShapeIndex& index = {}); private: // Returns a unique alias domain for this emitter. diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 23d2d4e87d26f4988ebddcf20f5a27af6a7fe0d6..1f6e3c829f890d68aa251b101f0402c120a19d61 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -15,53 +15,57 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, - const std::function& for_body_generator) { - If(ir_builder_->CreateICmpSLT(start, end), [&]() { - for_body_generator(start, /*is_first_iteration=*/true); - For(name, ir_builder_->CreateAdd(start, step), end, step, - [&](llvm::Value* iv) { for_body_generator(iv, false); }); + const std::function& for_body_generator) { + return If(ir_builder_->CreateICmpSLT(start, end), [&]() -> Status { + TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); + return For(name, ir_builder_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, - const std::function& for_body_generator) { + const std::function& + for_body_generator) { if (peel_first_iteration) { - For(name, start, end, step, true, - [&](llvm::Value* indvar, bool is_first_iteration) { - for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration)); - }); + return For(name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) -> Status { + return for_body_generator( + indvar, ir_builder_->getInt1(is_first_iteration)); + }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( name, start, end, step, ir_builder_, - /*prevent_unrolling=*/prevent_unrolling_, + /*unroll_mode=*/unroll_mode_, /*prevent_vectorization=*/prevent_vectorization_); ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); - for_body_generator(loop->GetIndVarValue(), - 
/*is_first_iteration=*/ir_builder_->CreateICmpEQ( - loop->GetIndVarValue(), start)); + TF_RETURN_IF_ERROR( + for_body_generator(loop->GetIndVarValue(), + /*is_first_iteration=*/ir_builder_->CreateICmpEQ( + loop->GetIndVarValue(), start))); llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_); + return Status::OK(); } } -void KernelSupportLibrary::If( - llvm::Value* condition, const std::function& true_block_generator, - const std::function& false_block_generator) { +Status KernelSupportLibrary::If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, "", ir_builder_); ir_builder_->SetInsertPoint(&if_data.true_block->back()); - true_block_generator(); + TF_RETURN_IF_ERROR(true_block_generator()); ir_builder_->SetInsertPoint(&if_data.false_block->back()); - false_block_generator(); + TF_RETURN_IF_ERROR(false_block_generator()); llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); + return Status::OK(); } void KernelSupportLibrary::EmitAndCallOutlinedKernel( diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 64b935bbf1fb9033cd2e1259b4639cd3780be711..e17c649e5272a9e7c0d5126083ab76542abfdf48 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -30,13 +31,14 @@ namespace xla { class KernelSupportLibrary { public: // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR. 
- // If `prevent_unrolling` is true then unrolling is explicitly disabled on - // every loop generated by this instance of KernelSupportLibrary. - explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = true, - bool prevent_vectorization = true) + // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop + // generated by this instance of KernelSupportLibrary. + explicit KernelSupportLibrary( + llvm::IRBuilder<>* ir_builder, + llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll, + bool prevent_vectorization = true) : ir_builder_(ir_builder), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} // Generates the following control flow structure: @@ -46,19 +48,41 @@ class KernelSupportLibrary { // for (i64 i = `start` + `step`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } - void For( + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator); + + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& - for_body_generator); + for_body_generator) { + CHECK_EQ(Status::OK(), + For(name, start, end, step, + [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& + for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } - void For( + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, /*start=*/ir_builder_->getInt64(start), - 
/*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -75,46 +99,101 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - llvm::Value* step, bool peel_first_iteration, - const std::function& + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); + + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, llvm::Value* step, + bool peel_first_iteration, + const std::function& + for_body_generator) { + TF_CHECK_OK(For( + name, start, end, step, peel_first_iteration, + [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + return For(name, /*start=*/start, /*end=*/end, + /*step=*/ir_builder_->getInt64(step), peel_first_iteration, for_body_generator); + } + + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + ForReturnVoid(name, /*start=*/start, /*end=*/end, + /*step=*/ir_builder_->getInt64(step), peel_first_iteration, + for_body_generator); + } - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, bool 
peel_first_iteration, - const std::function& - for_body_generator) { - For(name, /*start=*/start, /*end=*/end, - /*step=*/ir_builder_->getInt64(step), peel_first_iteration, - for_body_generator); + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + return For(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void For( + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - For(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + ForReturnVoid(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) { + return for_body_generator(indvar); + }); + } + + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, + const std::function& for_body_generator) { + return For(name, start, end, ir_builder_->getInt64(step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void For( + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - For(name, start, end, ir_builder_->getInt64(step), - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + ForReturnVoid(name, start, end, ir_builder_->getInt64(step), + for_body_generator); + } + + Status For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), 
for_body_generator); } - void For( + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -123,9 +202,25 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - void If(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() {}); + Status If(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = + []() -> Status { return Status::OK(); }); + + void IfReturnVoid(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() { + }) { + TF_CHECK_OK(If(condition, + [&]() { + true_block_generator(); + return Status::OK(); + }, + [&]() { + false_block_generator(); + return Status::OK(); + })); + } using ArgumentVector = tensorflow::gtl::ArraySlice; @@ -183,7 +278,7 @@ class KernelSupportLibrary { private: llvm::IRBuilder<>* ir_builder_; - bool prevent_unrolling_; + llvm_ir::UnrollMode unroll_mode_; bool prevent_vectorization_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 497b48ff227d7d1f158080529372df44b6932b24..9f867014fb015845448c4fcf9c165750f8a61935 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -34,7 +34,7 @@ namespace llvm_ir { ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* 
start_index, llvm::Value* end_index, - llvm::Value* step, bool prevent_unrolling, + llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) : prefix_(std::string(prefix)), suffix_(std::string(suffix)), @@ -42,15 +42,15 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, end_index_(end_index), step_(step), insert_before_bb_(nullptr), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling, bool prevent_vectorization) { + UnrollMode unroll_mode, bool prevent_vectorization) { std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, - end_index, step, prevent_unrolling, + end_index, step, unroll_mode, prevent_vectorization)); loop->Emit(ir_builder); return loop; @@ -147,11 +147,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { std::vector ForLoop::GetLoopMetadata( llvm::IRBuilder<>* ir_builder) { const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; + const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full"; const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; llvm::LLVMContext* ctx = &start_index_->getContext(); std::vector result; - if (prevent_unrolling_) { + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) { result.push_back(llvm::MDNode::get( *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)})); } @@ -162,6 +163,10 @@ std::vector ForLoop::GetLoopMetadata( llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); } + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)})); + } return result; } @@ -178,25 +183,25 @@ llvm::BasicBlock* 
ForLoop::CreateLoopBB(tensorflow::StringPiece name, std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), - prevent_unrolling, prevent_vectorization); + unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr loop(new ForLoop( - /*prefix=*/name_, suffix, start_index, end_index, stride, - prevent_unrolling, prevent_vectorization)); + /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode, + prevent_vectorization)); loop->Emit(ir_builder_); if (outer_loop_preheader_bb_ == nullptr) { @@ -215,23 +220,23 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), prevent_unrolling, + ir_builder_->getInt64(end_index), unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), prevent_unrolling, + 
ir_builder_->getInt64(stride), unroll_mode, prevent_vectorization); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index d915f95db134918a173a9711936bb1e2f1ea0d95..4e403cd994874c27453574283c5c573c876628db 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -34,6 +34,12 @@ limitations under the License. namespace xla { namespace llvm_ir { +enum class UnrollMode { + kDefaultUnroll, + kFullyUnroll, + kNoUnroll, +}; + // A class for constructing a for-loop in LLVM IR. class ForLoop { public: @@ -69,12 +75,13 @@ class ForLoop { // LLVM IR. If non-empty, it is prepended to the name of the induction // variable value and each basic block created for the loop. // - // If `prevent_unrolling` is true then emit metadata that directs LLVM to not - // unroll the generated loop. + // `unroll_mode` specifies the desired LLVM unrolling behavior for generated + // loop. static std::unique_ptr EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = false, bool prevent_vectorization = false); + UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // The names of the blocks follow LLVM's conventions. Control flow amongst the // blocks for the example C code looks like: @@ -128,7 +135,7 @@ class ForLoop { ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, - bool prevent_unrolling, bool prevent_vectorization); + UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. 
void Emit(llvm::IRBuilder<>* ir_builder); @@ -161,7 +168,7 @@ class ForLoop { llvm::BasicBlock* body_bb_; llvm::BasicBlock* exit_bb_; llvm::Value* indvar_; - bool prevent_unrolling_; + UnrollMode unroll_mode_; bool prevent_vectorization_; TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); @@ -182,34 +189,34 @@ class ForLoopNest { // Adds a loop to the nest. If no loop has been added yet then emit a loop at // the current insert point of the given builder. If one or more loops have - // been added then emit loop inside the body of the last added loop. If - // prevent_unrolling is true, then metadata is emitting directing LLVM to not - // unroll this loop. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + // been added then emit loop inside the body of the last added loop. + // unroll_mode is used to emit metadata that controls LLVM unrolling. + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* stride, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. 
- std::unique_ptr AddLoop(int64 start_index, int64 end_index, - int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, int64 stride, + tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(int64 start_index, int64 end_index, - tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ff64da87e9c9acf8a9d7ff87d3b1be7a9e9106bb..d18c9dee826eab5760d391bb8f7b5bd02ab659ae 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -193,6 +193,10 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // An Opaque is like a void*, use i8*. case OPAQUE: return llvm::Type::getInt8PtrTy(module->getContext()); + case TOKEN: + // Tokens do not have a physical representation, but the compiler needs + // some placeholder type, so use int8*. 
+ return llvm::Type::getInt8PtrTy(module->getContext()); default: LOG(FATAL) << "unsupported type " << element_type; } diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 1d9c9e0678765a779ec94e578e0e6f69d46b80de..296d04d4362b12fdc39798a016ca9e8795e02586 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..f9f9c7dcf788c24468cd474d9e7e20980135c1f0 --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -0,0 +1,358 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +StatusOr MultiOutputFusion::Run(HloModule* module) { + bool changed = false; + + for (auto* computation : module->MakeNonfusionComputations()) { + computation_ = computation; + reachability_ = computation_->ComputeReachability(); + candidates_.clear(); + candidates_index_.clear(); + all_fusion_candidates_.clear(); + + int64 index = 0; + for (auto it : computation_->MakeInstructionPostOrder()) { + candidates_.emplace_back(it); + InsertOrDie(&candidates_index_, it, index++); + } + + // Create the initial candidate list for each Node. + for (auto& node : candidates_) { + HloInstruction* instruction = node.hlo; + int64 instruction_id = get_candidate_id(instruction); + FusionCandidate& instr_node = candidates_[instruction_id]; + if (!IsFusible(instruction)) { + continue; + } + all_fusion_candidates_.push_back(instruction); + + std::vector candidates; + tensorflow::gtl::FlatSet candidates_set; + VLOG(10) << "Looking at instruction: " << instruction->name(); + for (auto operand : instruction->operands()) { + // Filter out the non-interesting instructions -- they + // will not generate the savings. 
+ if (!IsProfitableOperand(operand)) { + VLOG(10) << "Operand not profitable: " << operand->name(); + continue; + } + VLOG(10) << "Operand profitable: " << operand->name(); + for (auto user : operand->users()) { + VLOG(10) << "User: " << user->name(); + if (user == instruction || !IsFusible(user)) { + VLOG(10) << "User is not fusible, or is the instruction itself: " + << user->name(); + continue; + } + int64 user_id = get_candidate_id(user); + if (is_connected(instruction, user)) { + VLOG(10) << "User is connected: " << user->name(); + continue; + } + if (instruction_id < user_id && + user->opcode() == HloOpcode::kFusion) { + VLOG(10) << "User ID for user: " << user->name() << " is " + << user_id << " which is higher than " << instruction_id; + continue; + } + if (!LegalToFuse(instruction, user)) { + VLOG(10) << "User not legal to fuse: " << user->name(); + continue; + } + if (candidates_set.insert(user).second) { + VLOG(10) << "User added to candidate list: " << user->name(); + candidates.push_back(user); + } + } + } + + // Iterate over candidates rather than candidates_set to avoid + // nondeterminism. + for (auto candidate : candidates) { + int64 profit = GetProfit(instruction, candidate); + if (profit > 0) { + FusionCandidate& candidate_node = + candidates_[get_candidate_id(candidate)]; + instr_node.fusibles.emplace_back(candidate, profit); + candidate_node.fusibles.emplace_back(instruction, profit); + worklist_.emplace(instruction, candidate, profit); + } + } + } + if (Perform()) { + changed = true; + } + } + return changed; +} + +HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, + HloInstruction* instr2) { + HloInstruction* remaining = instr1; + HloInstruction* fused = instr2; + // Make sure that if only one of the instructions is a fusion, or if only one + // of the instructions is a multi-output fusion, it's what will be fused into. + // + // An invariant is that no bitcast nodes will show up in the middle of a + // fusion node. 
This invariant must hold in order for us to lower it. Given + // that, we require that during multi-output fusion, a fusion node ending with + // bitcast preserves its structure as a nested fusion instead of being + // merged and flattened. + if (fused->opcode() == HloOpcode::kFusion && + fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + std::swap(remaining, fused); + } + if (fused->IsMultiOutputFusion()) { + std::swap(remaining, fused); + } + + if (fused->opcode() == HloOpcode::kFusion && + fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + remaining->MergeFusionInstructionIntoMultiOutput(fused); + } else { + if (remaining->opcode() == HloOpcode::kFusion && + remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) { + auto parent_computation = remaining->parent(); + // Create a nested fusion node. + auto remaining_nested_fused = + parent_computation->AddInstruction(HloInstruction::CreateFusion( + remaining->shape(), HloInstruction::FusionKind::kLoop, + remaining)); + TF_CHECK_OK(parent_computation->ReplaceInstruction( + remaining, remaining_nested_fused)); + remaining = remaining_nested_fused; + } + remaining->FuseInstructionIntoMultiOutput(fused); + } + + return remaining; +} + +bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) { + // kConstant instruction will not have memory reads, so it won't be a profit + // source. Skip them. + if (instr->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(instr->shape())) { + return false; + } + // We don't target to fuse producer/consumer instructions -- this should + // be taken care of by the instruction_fusion pass. If instr has only + // one user, it will not have sibling instructions. We won't consider it.
+ if (instr->user_count() < 2) { + return false; + } + return true; +} + +void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { + HloInstruction* fusion = instr1; + HloInstruction* fused = instr2; + if (is_fused(instr1)) { + fusion = instr2; + fused = instr1; + } + + // Insert the newly created instruction (if any) into candidates_. + for (auto use : fusion->users()) { + if (candidates_index_.find(use) == candidates_index_.end()) { + int64 index = candidates_.size(); + candidates_.emplace_back(use); + InsertOrDie(&candidates_index_, use, index++); + } + } + FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)]; + FusionCandidate& fused_node = candidates_[get_candidate_id(fused)]; + + // Update the reachability graph. + UpdateReachability(fusion, fused, all_fusion_candidates_, + [this](HloInstruction* instr) { return is_fused(instr); }); + + // Update the fusible list for fusion. Variable new_fusibles keeps + // track of the new or changed entries. + std::vector> new_fusibles; + tensorflow::gtl::FlatSet in_list; + auto it = fusion_node.fusibles.begin(); + while (it != fusion_node.fusibles.end()) { + HloInstruction* instr = it->first; + if (is_fused(instr) || is_connected(fusion, instr)) { + it = fusion_node.fusibles.erase(it); + continue; + } + in_list.insert(instr); + int64 profit = GetProfit(instr, fusion); + if (profit > it->second) { + it->second = profit; + new_fusibles.emplace_back(instr, profit); + } + ++it; + } + + // Fused_node has been fused into fusion_node. Take the fusion candidates + // (fusibles) from fused_node and add them to the fusion_node's. Filter + // out those fusibles that are no longer valid (or already in the list).
+ for (const auto& it : fused_node.fusibles) { + HloInstruction* instr = it.first; + if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) { + continue; + } + if (in_list.count(instr) > 0) { + continue; + } + int64 profit = GetProfit(instr, fusion); + fusion_node.fusibles.emplace_back(instr, profit); + new_fusibles.emplace_back(instr, profit); + } + fused_node.fusibles.clear(); + + // Update the worklist_. + for (auto it : new_fusibles) { + worklist_.emplace(fusion, it.first, it.second); + } +} + +bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, + HloInstruction* instr2) { + if (instr1 == instr2) { + return false; + } + if (instr1->opcode() != HloOpcode::kFusion) { + return false; + } + + // Fusing nodes with 0 users makes no sense and the rest of the implementation + // doesn't support it either. + if (instr1->user_count() == 0 || instr2->user_count() == 0) { + return false; + } + + // Check whether any user of a multi-output fusion is not a get-tuple-element. + // If this is the case, we bail out because the transformation assumes + // all users are get-tuple-elements.
+ auto multioutput_user_is_not_gte = [](HloInstruction* instr) { + if (!instr->IsMultiOutputFusion()) { + return false; + } + for (auto user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return true; + } + } + return false; + }; + if (multioutput_user_is_not_gte(instr1) || + multioutput_user_is_not_gte(instr2)) { + return false; + } + + if (is_connected(instr1, instr2)) { + return false; + } + if (!ShapesCompatibleForFusion(instr1, instr2)) { + return false; + } + + return true; +} + +void MultiOutputFusion::UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip) { + for (auto instr : instrs_to_update) { + if (skip != nullptr && skip(instr)) { + continue; + } + if (reachability_->IsReachable(instr2, instr) && + reachability_->IsReachable(instr1, instr)) { + // If a candidate was already reachable by both, no update needed. + continue; + } + if (reachability_->IsReachable(instr2, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr1}, instr); + } + if (reachability_->IsReachable(instr1, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr2}, instr); + } + } +} + +bool MultiOutputFusion::Perform() { + int changed = false; + // Pick the top candidate from queue and try to merge. 
+ while (!worklist_.empty()) { + if (fuel_ <= 0) { + VLOG(2) << "No fusing: run out of fuel."; + break; + } + ToBeFused candidate = worklist_.top(); + worklist_.pop(); + + HloInstruction* instr1 = candidate.instr1; + HloInstruction* instr2 = candidate.instr2; + + if (is_fused(instr1) || is_fused(instr2)) { + continue; + } + + VLOG(1) << "Considering candidate profit_score=" << candidate.score + << "\n\t\tinstr1 = " << instr1->ToString() + << "\n\t\tinstr2 = " << instr2->ToString(); + + if (LegalToFuse(instr1, instr2)) { + VLOG(1) << "Fuse!"; + VLOG(2) << "Before multi_output_fusion:"; + VLOG(2) << "instr1: " << instr1->ToString(); + VLOG(2) << "\n" + << instr1->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + VLOG(2) << "instr2: " << instr2->ToString(); + if (instr2->opcode() == HloOpcode::kFusion) { + VLOG(2) << "\n" + << instr2->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + } + HloInstruction* ret = Fuse(instr1, instr2); + set_is_fused(ret == instr1 ? instr2 : instr1); + Update(instr1, instr2); + changed = true; + VLOG(2) << "After fusion, \t this: " << ret->name() << "\n" + << ret->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + auto users = ret->users(); + --fuel_; + } + } + if (DoProducerConsumerMultiOutputFusion(computation_)) { + changed = true; + } + return changed; +} + +bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion( + HloComputation* /*computation*/) { + return false; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..d9c36fa284347d1efa16d8d3e45da807c3b8bf8b --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -0,0 +1,160 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// This class implements the fusing of sibling fusion instructions that share +// common operands. +// It constructs the following associated data structures. +// (1) candidates_: stores the instruction and the set of instructions it can +// fuse to. +// (2) candidates_index_: maps instruction to id. +// (3) reachability_: reachability map in this computation. +// (4) all_fusion_candidates_: the vector of candidate instructions. +// (5) worklist_: a priority queue that contains pairs of instructions to be +// fused and their fusion profit scores. +// +// Function Perform() applies the optimization. It picks up the most profitable +// pair in the worklist_, checks if it's legal to fuse, and fuses the pair. +// After fusion, it updates the associated structure such as reachability_, +// candidates_ and worklist_. +// Note that the reachability map is updated based on the original computation.
+// This works because the reachability is monotonically increasing with +// instruction fusion. +class MultiOutputFusion : public HloPassInterface { + public: + MultiOutputFusion(int64 fuel) : fuel_(fuel) {} + + tensorflow::StringPiece name() const override { + return "multi_output_fusion"; + } + + // Run multi-output fusion on the given module. Returns whether the module + // was changed. + StatusOr Run(HloModule* module) override; + + protected: + // Main entry for the optimization. Returns true if the optimization happens. + bool Perform(); + + // Test if instr1 and instr2 have compatible shapes that can be legally + // fused. + virtual bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) = 0; + + // Whether the instruction is a candidate for fusion. + virtual bool IsFusible(HloInstruction* instr) = 0; + + // This function estimates the savings by merging instr1 and instr2 into one + // multi-output fusion instruction. + virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0; + + // Whether fusing the instruction can reduce memory reads. + virtual bool IsProfitableOperand(HloInstruction* instr); + + // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. + virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2); + + // Update the reachability map after fusing instr1 and instr2. + void UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip = nullptr); + + // Hook for multi-output fusion along producer-consumer edges. + // Returns whether any instructions were fused. + // + // TODO(b/80420762): Perform producer-consumer multi-output fusion in + // InstructionFusion instead. + virtual bool DoProducerConsumerMultiOutputFusion(HloComputation* computation); + + private: + // Fuse HloInstruction instr1 and instr2 and return the fused instruction.
+ // The other instruction is removed from its parent computation. + HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); + + // Update the internal data structures after instr1 and instr2 are fused into + // one fusion instruction. + void Update(HloInstruction* instr1, HloInstruction* instr2); + + // Optimization fuel is a compiler debugging technique that makes an + // optimization pass stop what it is doing after having made N changes to the + // program, where N is the fuel. By varying N, this can be used to find the + // first single change that makes a test fail. + int64 fuel_; + + // Computation for the pass. + HloComputation* computation_; + + // An internal data structure for each instruction in current computation. + // When an instruction is removed, member 'hlo' is set to nullptr. + struct FusionCandidate { + HloInstruction* hlo; + std::list> fusibles; + explicit FusionCandidate(HloInstruction* hlo) : hlo(hlo) {} + }; + std::vector candidates_; + + // A map that maps an instruction to the index_. + tensorflow::gtl::FlatMap candidates_index_; + + // The reachability map of current computation. + std::unique_ptr reachability_; + + // This stores all the candidate instructions in current computation. + std::vector all_fusion_candidates_; + + // The pair of candidates to be fused and the profit score. 
+ struct ToBeFused { + HloInstruction* instr1; + HloInstruction* instr2; + int64 score; + ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score) + : instr1(instr1), instr2(instr2), score(score) {} + bool operator<(const ToBeFused& rhs) const { return score < rhs.score; } + }; + std::priority_queue worklist_; + + int64 get_candidate_id(HloInstruction* instr) { + return FindOrDie(candidates_index_, instr); + } + + bool is_fused(HloInstruction* instr) { + return candidates_[get_candidate_id(instr)].hlo == nullptr; + } + + void set_is_fused(HloInstruction* instr) { + candidates_[get_candidate_id(instr)].hlo = nullptr; + } + + bool is_connected(HloInstruction* instr1, HloInstruction* instr2) { + return reachability_->IsConnected(instr1, instr2); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 0f26a025bf125f70199637894741540f89eae7e5..49ec38eb62c7b51c7a2d301d882cef032b288036 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -155,20 +155,15 @@ HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand, case HloOpcode::kConstant: { if (first_reshape_operand->opcode() == HloOpcode::kReshape) { VLOG(5) << "Adding reshape to kConstant operand"; - HloInstruction* reshape = computation->AddInstruction( + return computation->AddInstruction( HloInstruction::CreateReshape(new_shape, operand)); - operand->SetupDerivedInstruction(reshape); - return reshape; } else { CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose); VLOG(5) << "Adding transpose to kConstant operand"; std::vector inverse_permutation = InversePermutation(first_reshape_operand->dimensions()); - HloInstruction* transpose = - computation->AddInstruction(HloInstruction::CreateTranspose( - new_shape, operand, inverse_permutation)); - 
operand->SetupDerivedInstruction(transpose); - return transpose; + return computation->AddInstruction(HloInstruction::CreateTranspose( + new_shape, operand, inverse_permutation)); } } case HloOpcode::kRng: { diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index d01c35b99231310692f85d0f9fbf4f2c3709d44c..961158e677baa46465af3f1f9a62929d14547c30 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -348,8 +348,8 @@ StatusOr>> Service::BuildExecutables( module_protos[i]->entry_computation_name().c_str()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); - hlo_snapshots.push_back(std::move(hlo_snapshot)); } + hlo_snapshots.push_back(std::move(hlo_snapshot)); } VLOG(1) << "Computations:"; @@ -721,6 +721,15 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, executable_ptrs.push_back(executable.get()); } + for (int i = 0; i < executable_ptrs.size(); i++) { + if (executable_ptrs[i]->dumping_snapshot()) { + TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), + all_executors[i][0], + execute_backend_->transfer_manager(), + executable_ptrs[i]->hlo_snapshot())); + } + } + // Execute the generated executables in parallel and return the device // handles for each computation's output. ExecutionProfile profile; @@ -736,6 +745,18 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, *result->add_responses() = response; } + for (int i = 0; i < executable_ptrs.size(); i++) { + if (executable_ptrs[i]->dumping_snapshot()) { + TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(outputs[i], 0)); + TF_RETURN_IF_ERROR(RecordResult(*result_buffer, all_executors[i][0], + execute_backend_->transfer_manager(), + executable_ptrs[i]->hlo_snapshot())); + // Dump out the ith snapshot. 
+ TF_RETURN_IF_ERROR(executable_ptrs[i]->DumpHloSnapshot()); + } + } + VLOG(1) << "successfully completed 'execute-graph-parallel' request"; return Status::OK(); } @@ -835,6 +856,10 @@ StatusOr> Service::BuildExecutable( backend->compiler()->RunBackend( std::move(module), executor, device_allocator)); + if (!execution_directory_path.empty()) { + executable->set_hlo_snapshot(std::move(hlo_snapshot)); + } + return std::move(executable); } diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index d64b2b4d0afa15f8c0cf48b19c33e51a3d011eb0..8748a4c1447eca691abc0f7ca48feda48ceb86e1 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -26,14 +26,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/allocation_tracker.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" -#include "tensorflow/compiler/xla/service/compilation_cache.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -297,9 +295,6 @@ class Service : public ServiceInterface { // Tracks asynchronously launched executions via the API. ExecutionTracker execution_tracker_; - // Cache containing previously built Executables. - CompilationCache compilation_cache_; - // Backend to compile and execute computations on. 
std::unique_ptr execute_backend_; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index d624f548b1ba65e6f6dfd7b329e8c86ab29112a0..e25f5e67c719430c0e7a8e0bb059efdc01ea75f9 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -44,147 +44,18 @@ namespace xla { namespace { -// Return the UnaryOperation proto enum value associated with the given HLO -// opcode. -UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAbs: - return UNOP_ABS; - case HloOpcode::kCeil: - return UNOP_CEIL; - case HloOpcode::kClz: - return UNOP_CLZ; - case HloOpcode::kCos: - return UNOP_COS; - case HloOpcode::kExp: - return UNOP_EXP; - case HloOpcode::kExpm1: - return UNOP_EXPM1; - case HloOpcode::kFloor: - return UNOP_FLOOR; - case HloOpcode::kImag: - return UNOP_IMAG; - case HloOpcode::kIsFinite: - return UNOP_IS_FINITE; - case HloOpcode::kLog: - return UNOP_LOG; - case HloOpcode::kLog1p: - return UNOP_LOG1P; - case HloOpcode::kNot: - return UNOP_NOT; - case HloOpcode::kNegate: - return UNOP_NEGATE; - case HloOpcode::kReal: - return UNOP_REAL; - case HloOpcode::kRoundNearestAfz: - return UNOP_ROUND_NEAREST_AFZ; - case HloOpcode::kSign: - return UNOP_SIGN; - case HloOpcode::kSin: - return UNOP_SIN; - case HloOpcode::kSort: - return UNOP_SORT; - case HloOpcode::kTanh: - return UNOP_TANH; - default: - LOG(FATAL) << "Unhandled opcode for conversion to unary operation: " - << opcode; - } -} - -// Return the BinaryOperation proto enum value associated with the given HLO -// opcode. 
-BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAtan2: - return BINOP_ATAN2; - case HloOpcode::kComplex: - return BINOP_COMPLEX; - case HloOpcode::kMultiply: - return BINOP_MUL; - case HloOpcode::kAdd: - return BINOP_ADD; - case HloOpcode::kSubtract: - return BINOP_SUB; - case HloOpcode::kDivide: - return BINOP_DIV; - case HloOpcode::kEq: - return BINOP_EQ; - case HloOpcode::kGe: - return BINOP_GE; - case HloOpcode::kGt: - return BINOP_GT; - case HloOpcode::kLe: - return BINOP_LE; - case HloOpcode::kLt: - return BINOP_LT; - case HloOpcode::kNe: - return BINOP_NE; - case HloOpcode::kMaximum: - return BINOP_MAX; - case HloOpcode::kMinimum: - return BINOP_MIN; - case HloOpcode::kPower: - return BINOP_POW; - case HloOpcode::kRemainder: - return BINOP_REM; - case HloOpcode::kOr: - return BINOP_OR; - case HloOpcode::kAnd: - return BINOP_AND; - case HloOpcode::kShiftLeft: - return BINOP_SHIFT_LEFT; - case HloOpcode::kShiftRightArithmetic: - return BINOP_SHIFT_RIGHT_ARITHMETIC; - case HloOpcode::kShiftRightLogical: - return BINOP_SHIFT_RIGHT_LOGICAL; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the TernaryOperation proto enum value associated with the given HLO -// opcode. -TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kClamp: - return TRIOP_CLAMP; - case HloOpcode::kSelect: - return TRIOP_SELECT; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the VariadicOperation proto enum value associated with the given HLO -// opcode. -VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kTuple: - return VAROP_TUPLE; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - // Returns true if no element is present in slice more than once. 
bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -Status ExpectNotTupleOrOpaque(const Shape& shape, - tensorflow::StringPiece op_type) { - if (ShapeUtil::IsTuple(shape)) { - return InvalidArgument("Expected non-tuple argument for %s, but got %s.", - std::string(op_type).c_str(), - ShapeUtil::HumanString(shape).c_str()); - } else if (ShapeUtil::IsOpaque(shape)) { - return InvalidArgument("Expected non-opaque argument for %s, but got %s.", +Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument("Expected array argument for %s, but got %s.", std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); - } else { - return Status::OK(); } + return Status::OK(); } Status VerifyReducerShape(const ProgramShape& reducer_shape, @@ -321,84 +192,80 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return shape; } - return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), shape); -} + TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation")); -/* static */ StatusOr ShapeInference::InferUnaryOpShape( - UnaryOperation operation, const Shape& arg) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation")); - - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg)); - switch (operation) { - case UNOP_FLOOR: - case UNOP_CEIL: - if (!ShapeUtil::ElementIsFloating(arg)) { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + switch (opcode) { + case HloOpcode::kFloor: + case HloOpcode::kCeil: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( "Expected element type in shape to be floating for floor/ceil " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_COS: - case UNOP_SIN: - case UNOP_EXP: - case UNOP_EXPM1: - case UNOP_LOG: - case 
UNOP_LOG1P: - case UNOP_TANH: - if (!ShapeUtil::ElementIsFloating(arg) && - !ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kCos: + case HloOpcode::kSin: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kTanh: + if (!ShapeUtil::ElementIsFloating(shape) && + !ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( "Expected element type in shape to be floating or complex for " "sin/cos/exp/log/tanh operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_REAL: - case UNOP_IMAG: - if (!ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kReal: + case HloOpcode::kImag: + if (!ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( "Expected element type in shape to be complex for real/imag " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, F32); - case UNOP_ABS: - if (ShapeUtil::ElementIsComplex(arg)) { + return ShapeUtil::ChangeElementType(shape, F32); + case HloOpcode::kAbs: + if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( - arg, primitive_util::ComplexComponentType(arg.element_type())); + shape, primitive_util::ComplexComponentType(shape.element_type())); } - return arg; - case UNOP_CLZ: - case UNOP_NEGATE: - case UNOP_ROUND_NEAREST_AFZ: - case UNOP_SIGN: - case UNOP_SORT: - return arg; - - case UNOP_NOT: - if (arg.element_type() != PRED && - !primitive_util::IsIntegralType(arg.element_type())) { + return shape; + case HloOpcode::kClz: + case HloOpcode::kNegate: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSign: + case HloOpcode::kSort: + return shape; + + case HloOpcode::kNot: + if (shape.element_type() != PRED && + !primitive_util::IsIntegralType(shape.element_type())) { return 
InvalidArgument( "Expected pred or an integral element type in argument to Not " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; + return shape; - case UNOP_IS_FINITE: - if (!ShapeUtil::ElementIsFloating(arg)) { + case HloOpcode::kIsFinite: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "Expected element type in shape to be floating point for IsFinite " + "Expected element type in shape to be floating " + "point for IsFinite " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, PRED); + return ShapeUtil::ChangeElementType(shape, PRED); default: return InvalidArgument( "Unknown operation for unary shape inference: \"%s\".", - UnaryOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -415,8 +282,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const Shape* arg_shape = nullptr; PrimitiveType element_type = PRIMITIVE_TYPE_INVALID; for (const Shape* shape : arg_shapes) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*shape, "operand of concatenation")); + TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation")); if (!arg_shape) { arg_shape = shape; element_type = arg_shape->element_type(); @@ -463,6 +329,17 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } +/* static */ StatusOr ShapeInference::InferGenerateTokenShape( + tensorflow::gtl::ArraySlice arg_shapes) { + for (const Shape* arg_shape : arg_shapes) { + if (arg_shape->element_type() != TOKEN) { + return InvalidArgument( + "Operands of token instructions must be TOKEN types."); + } + } + return ShapeUtil::MakeTokenShape(); +} + /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto 
old_element_type = operand_shape.element_type(); @@ -473,12 +350,13 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } - if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + if (!ShapeUtil::IsArray(operand_shape) || + !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. return InvalidArgument( - "Convert does not allow tuples, so cannot convert from %s to %s.", + "Convert does not allow non-arrays, so cannot convert from %s to %s.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -495,7 +373,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } - if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + if (!ShapeUtil::IsArray(operand_shape) || + !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. 
@@ -542,7 +421,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { - if (ShapeUtil::IsTuple(operand_shape)) { + if (!ShapeUtil::IsArray(operand_shape)) { return InvalidArgument( "Pad operation does not support tuple-shape operands."); } @@ -681,8 +560,8 @@ Status ValidateDotDimensionNumbers( /* static */ StatusOr ShapeInference::InferDotOpShape( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); auto fail = [lhs, rhs](const string& addendum) -> Status { string message = tensorflow::strings::Printf( @@ -768,8 +647,9 @@ Status ValidateDotDimensionNumbers( } /* static */ StatusOr -ShapeInference::InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs) { +ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, + const Shape& lhs, + const Shape& rhs) { TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)); // The shapes have to be compatible. 
That is, if some dimension d has a @@ -787,7 +667,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", - BinaryOperation_Name(operation).c_str(), + HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -797,8 +677,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions) { if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring @@ -899,18 +778,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Binary op %s with different element types: %s and %s.", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), + HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -943,10 +819,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? 
rhs : lhs; // After InDim broadcasting, perform degenerate dimensions broadcasting. - TF_ASSIGN_OR_RETURN( - Shape indim_broadcast_shape, - InferInDimBroadcastShape(operation, smaller_shape, larger_shape, - broadcast_dimensions)); + TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape, + InferInDimBroadcastShape(smaller_shape, larger_shape, + broadcast_dimensions)); return InferDegenerateDimensionBroadcastShape( operation, indim_broadcast_shape, larger_shape); @@ -955,51 +830,44 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) { - return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(), - rhs->shape(), /*broadcast_dimensions=*/{}); + return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(), + /*broadcast_dimensions=*/{}); } /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs, - broadcast_dimensions); -} - -/* static */ StatusOr ShapeInference::InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { VLOG(2) << tensorflow::strings::Printf( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), + HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str(), Join(broadcast_dimensions, ", ").c_str()); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - lhs, tensorflow::strings::StrCat("lhs of binary operation ", - BinaryOperation_Name(operation)))); - 
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - rhs, tensorflow::strings::StrCat("rhs of binary operation ", - BinaryOperation_Name(operation)))); - switch (operation) { - case BINOP_MAX: - case BINOP_MIN: - case BINOP_SUB: - case BINOP_ADD: - case BINOP_ATAN2: - case BINOP_POW: - case BINOP_DIV: - case BINOP_REM: - case BINOP_MUL: - case BINOP_SHIFT_LEFT: - case BINOP_SHIFT_RIGHT_ARITHMETIC: - case BINOP_SHIFT_RIGHT_LOGICAL: - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + TF_RETURN_IF_ERROR( + ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ", + HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR( + ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ", + HloOpcodeString(opcode)))); + switch (opcode) { + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kSubtract: + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kPower: + case HloOpcode::kDivide: + case HloOpcode::kRemainder: + case HloOpcode::kMultiply: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_COMPLEX: { + case HloOpcode::kComplex: { if (!ShapeUtil::ElementIsFloating(lhs)) { return InvalidArgument( "Expected element type in shape to be floating for complex compose " @@ -1007,7 +875,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(lhs.element_type()).c_str()); } TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); @@ -1015,8 +883,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return Unimplemented("Complex component type is not implemented."); } } - case BINOP_AND: - case BINOP_OR: + case HloOpcode::kAnd: 
+ case HloOpcode::kOr: if (lhs.element_type() != PRED && !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( @@ -1024,24 +892,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "got %s.", PrimitiveType_Name(lhs.element_type()).c_str()); } - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_EQ: - case BINOP_GE: - case BINOP_GT: - case BINOP_LE: - case BINOP_LT: - case BINOP_NE: { + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: { TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); return ShapeUtil::ChangeElementType(shape, PRED); } default: return Unimplemented( "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", - BinaryOperation_Name(operation).c_str(), - lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str()); + HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(), + rhs.ShortDebugString().c_str()); } } @@ -1053,23 +921,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) { - return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs); -} - -/* static */ StatusOr ShapeInference::InferTernaryOpShape( - TernaryOperation operation, const Shape& lhs, const Shape& rhs, - const Shape& ehs) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs)); - switch (operation) { - case TRIOP_CLAMP: + switch (opcode) { + case HloOpcode::kClamp: return InferClampShape(lhs, rhs, ehs); - case TRIOP_SELECT: + case 
HloOpcode::kSelect: return InferSelectShape(lhs, rhs, ehs); default: return InvalidArgument("Unknown operation %s.", - TernaryOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -1086,18 +948,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes) { - return InferVariadicOpShape(OpcodeToVariadicOperation(opcode), - operand_shapes); -} - -/* static */ StatusOr ShapeInference::InferVariadicOpShape( - VariadicOperation operation, - tensorflow::gtl::ArraySlice operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); } - switch (operation) { - case VAROP_TUPLE: { + switch (opcode) { + case HloOpcode::kTuple: { Shape result = ShapeUtil::MakeTupleShape({}); for (const Shape* shape : operand_shapes) { ShapeUtil::AppendShapeToTuple(*shape, &result); @@ -1106,7 +961,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } default: return InvalidArgument("Unknown operation %s.", - VariadicOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -1121,15 +976,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // All arguments must have the same shape. 
const Shape* arg_shape = arg_shapes[0]; for (size_t i = 1; i < arg_shapes.size(); ++i) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map")); + TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map")); if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { continue; } - if (!ShapeUtil::IsTuple(*arg_shapes[i]) && - !ShapeUtil::IsTuple(*arg_shape) && - ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], + if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { if (ShapeUtil::IsScalar(*arg_shapes[i])) { continue; @@ -1212,11 +1064,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, int64 feature_index) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - offset_shape, "offset input of batch norm training")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - scale_shape, "scale input of batch norm training")); + ExpectArray(operand_shape, "operand of batch norm training")); + TF_RETURN_IF_ERROR( + ExpectArray(offset_shape, "offset input of batch norm training")); + TF_RETURN_IF_ERROR( + ExpectArray(scale_shape, "scale input of batch norm training")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == Status::OK()); @@ -1318,11 +1170,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& offset_shape, const Shape& mean_shape, const Shape& variance_shape, int64 feature_index) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - offset_shape, "offset input of batch norm inference")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - scale_shape, "scale input of batch norm inference")); + ExpectArray(operand_shape, "operand of batch norm inference")); + TF_RETURN_IF_ERROR( 
+ ExpectArray(offset_shape, "offset input of batch norm inference")); + TF_RETURN_IF_ERROR( + ExpectArray(scale_shape, "scale input of batch norm inference")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == Status::OK()); @@ -1465,16 +1317,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& mean_shape, const Shape& var_shape, const Shape& output_grad_shape, int64 feature_index) { + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad")); + ExpectArray(scale_shape, "scale input of batch norm grad")); + TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad")); + TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - output_grad_shape, "output_grad input of batch norm grad")); + ExpectArray(output_grad_shape, "output_grad input of batch norm grad")); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape)); @@ -1623,8 +1472,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dnums) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of 
convolution")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( @@ -1859,7 +1708,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( tensorflow::gtl::ArraySlice operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum")); + ExpectArray(*operand_shape, "operand of cross replica sum")); } if (operand_shapes.size() == 1) { return *operand_shapes[0]; @@ -1901,8 +1750,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window")); + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape, operand_shape.element_type())); return InferWindowOutputShape(operand_shape, window, @@ -1915,7 +1763,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Window& window, const Shape& source_shape, const Shape& init_value_shape, const ProgramShape& scatter_shape) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter")); + ExpectArray(operand_shape, "operand of select-and-scatter")); // Check if the select function has a proper shape of (T,T) -> PRED. 
if (select_shape.parameters_size() != 2) { @@ -1980,7 +1828,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( Join(starts, ",").c_str(), Join(limits, ",").c_str(), Join(strides, ",").c_str()); }; - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice")); + TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s starts={%s} limits={%s}", ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(), @@ -2039,10 +1887,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, tensorflow::gtl::ArraySlice slice_sizes) { + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape, - "start indices of dynamic slice")); + ExpectArray(start_indices_shape, "start indices of dynamic slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", @@ -2100,11 +1947,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& operand_shape, const Shape& update_shape, const Shape& start_indices_shape) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice")); + ExpectArray(operand_shape, "operand of dynamic update slice")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - start_indices_shape, "start indices of dynamic update slice")); + ExpectArray(update_shape, "update of dynamic update slice")); + TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, + "start indices of dynamic update slice")); VLOG(2) << tensorflow::strings::Printf( "updating slice of shape %s at dynamic start_indices %s with update " 
@@ -2172,8 +2019,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /*static */ StatusOr ShapeInference::InferReverseShape( const Shape& operand_shape, tensorflow::gtl::ArraySlice dimensions) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of reverse")); + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); if (!AllUnique(dimensions)) { return InvalidArgument("a dimension number is duplicated in reverse"); } @@ -2303,7 +2149,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); for (int64 size : broadcast_sizes) { if (size < 0) { return InvalidArgument("Broadcast with negative dimension size %lld.", @@ -2322,7 +2168,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferReshapeShape( const Shape& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); Shape inferred_shape = ShapeUtil::MakeShape(operand.element_type(), new_sizes); @@ -2354,7 +2200,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferTransposeShape( const Shape& operand, tensorflow::gtl::ArraySlice dimensions) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); std::vector indices(ShapeUtil::Rank(operand)); std::iota(indices.begin(), indices.end(), 0); @@ -2375,9 +2221,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // "degenerate" cases, as with binary elementwise ops. 
/* static */ StatusOr ShapeInference::InferClampShape( const Shape& min, const Shape& operand, const Shape& max) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); + TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand")); + TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("Clamp with different operand types: %s, %s, %s.", @@ -2576,9 +2422,9 @@ static Status ValidateGatherDimensionNumbers( const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - gather_indices_shape, "gather indices operand of gather op")); + ExpectArray(input_shape, "input tensor operand gather op")); + TF_RETURN_IF_ERROR( + ExpectArray(gather_indices_shape, "gather indices operand of gather op")); if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 9da2c99b4177f08ece8daabaf2922ddd7e947a1b..eef6e62fc8d933452ebc3f9a5b8bc49828455be5 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -46,8 +46,6 @@ class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the // given input shape. 
- static StatusOr InferUnaryOpShape(UnaryOperation operation, - const Shape& arg); static StatusOr InferUnaryOpShape(HloOpcode opcode, const Shape& shape); static StatusOr InferUnaryOpShape(HloOpcode opcode, @@ -55,9 +53,6 @@ class ShapeInference { // Infers the shape produced by applying the given binary operation to the // given input shapes. - static StatusOr InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); static StatusOr InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); @@ -67,9 +62,6 @@ class ShapeInference { // Infers the shape produced by applying the given ternary operation to the // given input shapes. - static StatusOr InferTernaryOpShape(TernaryOperation operation, - const Shape& lhs, const Shape& rhs, - const Shape& ehs); static StatusOr InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs); @@ -80,9 +72,6 @@ class ShapeInference { // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. - static StatusOr InferVariadicOpShape( - VariadicOperation operation, - tensorflow::gtl::ArraySlice operand_shapes); static StatusOr InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes); @@ -227,6 +216,13 @@ class ShapeInference { static StatusOr InferConcatOpShape( tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); + // Infers the shape produced by a kGenerateToken operation. Trivially this + // shape is always a TOKEN shape. However, ShapeInference serves two purposes: + // inferring shapes and checking operand shapes. This method verifies that the + // operand shapes are all TOKENs. 
+ static StatusOr InferGenerateTokenShape( + tensorflow::gtl::ArraySlice arg_shapes); + // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. @@ -279,7 +275,7 @@ class ShapeInference { // the LHS and a single element in the RHS to produce a single output element, // even in the presence of broadcasting of one of the operands over the other. static StatusOr InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); // Helper for inferring the shape of Clamp ops. @@ -295,7 +291,7 @@ class ShapeInference { // dimension broadcasting (a dimension of size 1 in one operand is broadcast // up to match the size of the dimension in the other operand). static StatusOr InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs); + HloOpcode operation, const Shape& lhs, const Shape& rhs); // Helper for inferring shapes of binary operations using "InDim" // broadcasting. This is the broadcasting used in the *InDim binary operations @@ -303,8 +299,7 @@ class ShapeInference { // lower-rank shape than larger_shape. Returns the shape that the // smaller_shape is broadcast to. 
static StatusOr InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions); TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 0e61994a786b53a295ef9c9c2287b28fbf754d9b..bafe14d6f45f851924c37908d4c93bbff2dac459 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -101,8 +101,8 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest { TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = ShapeInference::InferUnaryOpShape( - UnaryOperation::UNOP_NEGATE, matrix_shape); + auto inferred_status = + ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie())); } @@ -110,14 +110,14 @@ TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) { Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple); + HloOpcode::kSelect, pred_, tuple, tuple); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } @@ -125,34 +125,34 @@ 
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) { auto predarray = ShapeUtil::MakeShape(PRED, {64, 48}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectBadShapes) { auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("pred operand must have PRED")); auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), - matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, + matrix_64_48_); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("with non-scalar predicate with dimensionality")); // Tuples have a TUPLE element type and cannot be the pred of a select. 
auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}), + HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}), ShapeUtil::MakeTupleShape({f32_, f32_}), ShapeUtil::MakeTupleShape({f32_, f32_})); ASSERT_FALSE(inferred_status_error4.ok()); @@ -162,102 +162,98 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { TEST_F(ShapeInferenceTest, ClampAllMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, - matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampAllScalar) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_); + auto inferred_status = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, 
matrix_64_48_, f32_, matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_); + HloOpcode::kClamp, matrix_64_48_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_); + HloOpcode::kClamp, f32_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_); + HloOpcode::kClamp, f32_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampBadShapes) { // Type mismatch - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_) - .ok()); - // Dimension mismatch ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_64_, vector_32_, vector_32_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_64_, vector_32_) + 
ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_32_, vector_64_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_) .ok()); - // Dimension mismatch, where one operand is a scalar + // Dimension mismatch ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_) + HloOpcode::kClamp, vector_64_, vector_32_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_64_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_32_, vector_64_) + .ok()); + // Dimension mismatch, where one operand is a scalar + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, vector_32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, f32_, vector_32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, + vector_64_, vector_32_) .ok()); } TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, const tensorflow::gtl::ArraySlice& bcast) { - return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX, - lhs, rhs, bcast); + return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, + bcast); }; // Inputs must be FP. 
ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok()); @@ -292,8 +288,8 @@ TEST_F(ShapeInferenceTest, Complex) { } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { - StatusOr result = ShapeInference::InferVariadicOpShape( - VariadicOperation::VAROP_TUPLE, {&s32_, &f32_}); + StatusOr result = + ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_}); ASSERT_IS_OK(result.status()); ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), ShapeUtil::MakeTupleShape({s32_, f32_}))); @@ -804,8 +800,8 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) { TEST_F(ShapeInferenceTest, InferPowShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {}); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kPower, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie())); } @@ -813,7 +809,7 @@ TEST_F(ShapeInferenceTest, InferPowShape) { TEST_F(ShapeInferenceTest, InferCompareShapeEq) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -822,7 +818,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeEq) { TEST_F(ShapeInferenceTest, InferCompareShapeGe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -831,7 +827,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGe) { 
TEST_F(ShapeInferenceTest, InferCompareShapeGt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -840,7 +836,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGt) { TEST_F(ShapeInferenceTest, InferCompareShapeLe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -849,7 +845,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLe) { TEST_F(ShapeInferenceTest, InferCompareShapeLt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -858,7 +854,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLt) { TEST_F(ShapeInferenceTest, InferCompareShapeNe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -1111,22 +1107,22 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); const Shape vec16 = 
ShapeUtil::MakeShape(F32, {16}); - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {1}); + auto inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {0}); + auto inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0}); ASSERT_FALSE(inferred_status_mismatch.ok()); - inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {0}); + inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {1}); + inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1}); ASSERT_FALSE(inferred_status_mismatch.ok()); } @@ -1138,17 +1134,17 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8}); auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2}); + HloOpcode::kAdd, cube, matrix8_4, {1, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2}); + HloOpcode::kAdd, cube, matrix16_4, {0, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - 
BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1}); + HloOpcode::kAdd, cube, matrix16_8, {0, 1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); } @@ -1162,43 +1158,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8}); // "magical" broadcast rejected - auto inferred_status_error1 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {}); + auto inferred_status_error1 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Automatic")); // broadcast_dimension out of bounds for tensor's rank - auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {3}); + auto inferred_status_error2 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension - auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {0}); + auto inferred_status_error3 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); + HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), HasSubstr("broadcast_dimensions has to match")); // 
there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); + HloOpcode::kAdd, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); + HloOpcode::kAdd, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), HasSubstr("dimension 0 mismatch")); @@ -1207,13 +1203,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { // in a proper (strictly increasing) order, even if the lower-rank array // matches the higher-rank array in many different ways. auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); ASSERT_THAT(inferred_status_error7.status().error_message(), HasSubstr("dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); ASSERT_THAT(inferred_status_error8.status().error_message(), HasSubstr("dimensions order is wrong")); @@ -1315,7 +1311,7 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT( inferred_status_error4.status().error_message(), - HasSubstr("Expected non-tuple argument for operand of concatenation")); + HasSubstr("Expected array argument for operand of concatenation")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto 
inferred_status_error5 = ShapeInference::InferConcatOpShape( @@ -1391,7 +1387,7 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { ShapeInference::InferReverseShape(tuple_shape, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("Expected non-tuple argument")); + HasSubstr("Expected array argument")); } TEST_F(ShapeInferenceTest, Call) { @@ -1690,7 +1686,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Expected non-tuple argument for input")) + HasSubstr("Expected array argument for input")) << statusor.status(); } @@ -1704,7 +1700,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Expected non-tuple argument for gather indices")) + HasSubstr("Expected array argument for gather indices")) << statusor.status(); } diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index ba16dc640e2d2974eab4fc8b134a6e33c03e3b85..49e1f873192f800056a2272f7d4f698898b0f8a1 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -178,7 +178,6 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); - convolution.SetupDerivedInstruction(new_conv.get()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 3139801ea3130324f48d728dc6f739f709e55911..cccb8f2fbb0266bbf1f40b09170938a1e5d3e78d 100644 --- 
a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -176,7 +176,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build(mul)); HloInstruction* call = module->OutlineExpressionFromComputation( - {add, sub, mul}, "", entry_computation); + {add, sub, mul}, "entry", entry_computation); EXPECT_EQ(call, entry_computation->root_instruction()); HloComputation* callee_computation = call->to_apply(); // The arguments to the call should be const1, const2, and const3. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index bb634e6573ffceeaa66e0ac9141fe7e3a39ed602..eb6d1ada6b553f998fe06917dfdf0b5092cd79cd 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -723,15 +723,16 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return false; } if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. 
+ return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -789,8 +790,12 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return param_uses.size() == 1 && param_uses[0].first == callee_root && callee_root->IsElementwiseOnOperand(param_uses[0].second); } - // Check if 'user' is element-wise. - return user->IsElementwise(); + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + // + // Multi-output fusion will fail the check here as tuples are not considered + // an elementwise operation. + return user->IsElementwiseOnOperand(user->operand_index(operand)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index f558316b05b168a6f100e8ef69adfd9dbc023102..5734f284071944bc22011405898cf86f33dc48d7 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1148,5 +1148,30 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { call, {})); } +TEST_F(CanShareOperandBufferWithUserTest, LoopFusionWithElementwiseOperand) { + Shape full_shape = ShapeUtil::MakeShape(F32, {16, 32}); + Shape broadcast_shape = ShapeUtil::MakeShape(F32, {16}); + + auto builder = HloComputation::Builder(TestName() + "_fusion"); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, full_shape, "full")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, broadcast_shape, "small")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(full_shape, param1, {0})); + auto add = 
builder.AddInstruction(HloInstruction::CreateBinary( + full_shape, HloOpcode::kAdd, param0, broadcast)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, broadcast}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc index d668855084a884518b338cdf396a9330b9f43a2b..77bdcc9de0d830991208a1db271d009bccaf550e 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -30,10 +30,17 @@ limitations under the License. namespace xla { +TupleSimplifier::TupleSimplifier(bool exclude_entry_computation) : + exclude_entry_computation_(exclude_entry_computation) {} + StatusOr TupleSimplifier::Run(HloModule* module) { // Initially add all GTE and Tuple instructions to the worklist. 
std::queue worklist; for (auto* computation : module->computations()) { + if (exclude_entry_computation_ && + computation == module->entry_computation()) { + continue; + } for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kTuple || instruction->opcode() == HloOpcode::kGetTupleElement) { @@ -69,7 +76,6 @@ StatusOr TupleSimplifier::Run(HloModule* module) { // Tuple // HloInstruction* top_tuple = nullptr; - HloInstruction* first_gte = nullptr; bool can_simplify = true; for (int64 operand_number = 0; operand_number < instruction->operand_count(); ++operand_number) { @@ -79,17 +85,10 @@ StatusOr TupleSimplifier::Run(HloModule* module) { can_simplify = false; break; } - if (first_gte == nullptr) { - first_gte = operand; - } else if (!first_gte->has_compatible_sharding(operand)) { - can_simplify = false; - break; - } if (top_tuple == nullptr) { top_tuple = operand->mutable_operand(0); if (!ShapeUtil::Compatible(top_tuple->shape(), - instruction->shape()) || - !instruction->has_compatible_sharding(top_tuple)) { + instruction->shape())) { can_simplify = false; break; } @@ -118,14 +117,12 @@ StatusOr TupleSimplifier::Run(HloModule* module) { HloInstruction* element_source = instruction->mutable_operand(0)->mutable_operand( instruction->tuple_index()); - if (instruction->has_compatible_sharding(element_source)) { - changed = true; - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); - for (HloInstruction* user : element_source->users()) { - if (user->opcode() == HloOpcode::kTuple || - user->opcode() == HloOpcode::kGetTupleElement) { - worklist.push(user); - } + changed = true; + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); + for (HloInstruction* user : element_source->users()) { + if (user->opcode() == HloOpcode::kTuple || + user->opcode() == HloOpcode::kGetTupleElement) { + worklist.push(user); } } } diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h 
b/tensorflow/compiler/xla/service/tuple_simplifier.h index e5e9b10b5bf3f452d1bfec476b8d5c7d74c4f4e8..750950188312c5077d487f2feef0606f07839432 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -27,13 +27,20 @@ namespace xla { // the module. class TupleSimplifier : public HloPassInterface { public: - TupleSimplifier() {} + TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} + explicit TupleSimplifier(bool exclude_entry_computation); ~TupleSimplifier() override {} tensorflow::StringPiece name() const override { return "tuple-simplifier"; } // Run tuple simplification on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + private: + // When set, this pipeline stage will perform optimization of all computations + // apart from the module's entry computation. This is used by Graphcore's + // backend. + bool exclude_entry_computation_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index ca9ae91281fce5ee061d066fc3e538dbbc09f6b3..d3635eae81ec7017f9bf6a69250d10716309c9ec 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -42,6 +42,12 @@ class TupleSimplifierTest : public HloTestBase { TF_ASSERT_OK(changed_status.status()); EXPECT_EQ(change_expected, changed_status.ValueOrDie()); } + void Run(HloModule* module, bool change_expected, bool exclude_entry) { + TupleSimplifier simplifier(exclude_entry); + auto changed_status = simplifier.Run(module); + TF_ASSERT_OK(changed_status.status()); + EXPECT_EQ(change_expected, changed_status.ValueOrDie()); + } const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( @@ -211,5 +217,76 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { 
EXPECT_THAT(computation->root_instruction(), tuple); } +TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { + // Verify that the root computation can be excluded + auto module = CreateNewModule(); + + HloInstruction* p0; + HloInstruction* p1; + HloComputation* c0; + HloComputation* c1; + HloComputation* entry; + + { + HloComputation::Builder builder(TestName() + "_1"); + p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c0 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_2"); + p1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c1 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_Entry"); + HloInstruction* tuple_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* call0 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0)); + HloInstruction* call1 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1)); + 
HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0)); + HloInstruction* gte3 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3})); + + entry = module->AddEntryComputation(builder.Build()); + } + + Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); + + EXPECT_THAT(c0->root_instruction(), p0); + EXPECT_THAT(c1->root_instruction(), p1); + EXPECT_THAT(entry->instruction_count(), 9); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.h b/tensorflow/compiler/xla/service/versioned_computation_handle.h deleted file mode 100644 index 5732a56caffa31dde52dff5c2775f9fde0cacfbd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/versioned_computation_handle.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// A data structure encapsulating a ComputationHandle and version value of that -// computation. This object is used to unambiguously refer to a particular -// computation in the service. -struct VersionedComputationHandle { - // A version value unambiguously specifying the state of the computation at a - // particular point in time as it is being built. This value is the - // ComputationDataHandle of the current root instruction. - using Version = int64; - - ComputationHandle handle; - Version version; - - string ToString() const; - bool operator==(const VersionedComputationHandle& other) const { - return (handle.handle() == other.handle.handle()) && - (version == other.version); - } - bool operator<(const VersionedComputationHandle& other) const { - return ((handle.handle() < other.handle.handle()) || - ((handle.handle() == other.handle.handle()) && - (version < other.version))); - } -}; - -std::ostream& operator<<(std::ostream& out, - const VersionedComputationHandle& versioned_handle); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc index aa40b5cb264803097f52966d6f61f1f41b6b3017..44b0ec5cd4c1d406467007fcc530e919d602c438 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -32,11 +32,11 @@ StatusOr ZeroSizedHloElimination::Run(HloModule* module) { for (HloComputation* comp : module->MakeNonfusionComputations()) { for 
(HloInstruction* instruction : comp->MakeInstructionPostOrder()) { if (instruction->HasSideEffect() || - ShapeUtil::IsTuple(instruction->shape())) { + !ShapeUtil::IsArray(instruction->shape())) { continue; } if (comp->IsRemovable(instruction) && - ShapeUtil::HasZeroElements(instruction->shape())) { + ShapeUtil::IsZeroElementArray(instruction->shape())) { TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( instruction, HloInstruction::CreateConstant( Literal::CreateFromShape(instruction->shape())))); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 5b14953ebb243da7b9be6eafd46160db8bc62707..18e54d23c241ae0d4c61d8be79ff021dfb02a3e6 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -47,6 +47,9 @@ struct ShapeTreeNode { // Children of this node, as indices into the container's nodes_ array. std::vector children; + // Tells whether this is a leaf node. + bool is_leaf = true; + explicit ShapeTreeNode(ShapeIndex index) : ShapeTreeNode(std::move(index), T()) {} ShapeTreeNode(ShapeIndex index, T data) @@ -122,9 +125,7 @@ class ShapeTree { // Returns true if the node at the given index is a leaf node (an array // shape). 
- bool IsLeaf(const ShapeIndex& index) const { - return Lookup(index)->children.empty(); - } + bool IsLeaf(const ShapeIndex& index) const { return Lookup(index)->is_leaf; } ShapeTree(const ShapeTree&) = default; ShapeTree& operator=(const ShapeTree&) = default; @@ -311,16 +312,14 @@ class ShapeTreeIterator : nodes_(nodes), node_(std::move(node)), iterate_leaves_only_(iterate_leaves_only) { - while (iterate_leaves_only && node_ != nodes_->end() && - !node_->children.empty()) { + while (iterate_leaves_only && node_ != nodes_->end() && !node_->is_leaf) { ++node_; } } ShapeTreeIterator& operator++() { ++node_; - while (iterate_leaves_only_ && node_ != nodes_->end() && - !node_->children.empty()) { + while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) { ++node_; } return *this; @@ -333,8 +332,7 @@ class ShapeTreeIterator ShapeTreeIterator& operator--() { --node_; - while (iterate_leaves_only_ && node_ > nodes_->begin() && - !node_->children.empty()) { + while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) { --node_; } return *this; @@ -358,7 +356,7 @@ class ShapeTreeIterator ContainerType* nodes_; IteratorType node_; // True if we should not include interior nodes in our walk. 
- bool iterate_leaves_only_; + const bool iterate_leaves_only_; }; template @@ -379,6 +377,7 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); node->children.reserve(size); + node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); for (int i = 0; i < size; ++i) { @@ -395,6 +394,7 @@ void ShapeTree::InitChildren(const Shape& shape, Node* node) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); node->children.reserve(size); + node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); for (int i = 0; i < size; ++i) { diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index dc5facf1581c07fbb74dfcee95025692938632bd..51de82e95746281ed6e587b545dc933b48ce1ad4 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -116,6 +116,11 @@ TEST_F(ShapeTreeTest, InitValueConstructor) { TestInitValueConstructor(nested_tuple_shape_, 10); } +TEST_F(ShapeTreeTest, EmptyTupleMustHaveNoLeaves) { + ShapeTree shape_tree{ShapeUtil::MakeTupleShape({})}; + EXPECT_EQ(0, shape_tree.leaf_count()); +} + TEST_F(ShapeTreeTest, ArrayShape) { ShapeTree shape_tree{array_shape_}; *shape_tree.mutable_element({}) = 42; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ce4d0079ee5eb28444509c712ec1a34037dc244a..c85fb20e01c1c8b7a8fc0d2b10881e5f9feed977 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -363,7 +363,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::IsNil(const Shape& shape) { - return IsTuple(shape) ? 
IsEmptyTuple(shape) : HasZeroElements(shape); + return IsEmptyTuple(shape); } /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { @@ -413,8 +413,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( std::multiplies()); } -/* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) { - return ElementsIn(shape) == 0; +/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { + return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } /* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { @@ -645,15 +645,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - if (IsArray(lhs)) { - return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); - } else if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); - } else { - // Opaque, token, etc types are vacuously compatible. 
- return true; - } + return CompareShapes(lhs, rhs, /*compare_layouts=*/false); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, @@ -903,6 +895,21 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return *return_shape; } +/* static */ StatusOr ShapeUtil::TryGetSubshape( + const Shape& shape, ShapeIndexView index) { + const Shape* return_shape = &shape; + for (auto i : index) { + if (!IsTuple(*return_shape) || i < 0 || + i >= return_shape->tuple_shapes_size()) { + return InvalidArgument( + "Shape index %s not a valid subshape index for tuple with shape %s", + index.ToString().c_str(), shape.DebugString().c_str()); + } + return_shape = &return_shape->tuple_shapes(i); + } + return return_shape; +} + /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape, ShapeIndexView index) { Shape* return_shape = shape; @@ -939,68 +946,6 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return leaves; } -/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { - CHECK(IsArray(shape)); - - std::vector dimension_sizes; - std::vector degenerate_dimensions; - for (int64 i = 0; i < shape.dimensions_size(); ++i) { - if (shape.dimensions(i) == 1) { - degenerate_dimensions.push_back(i); - } else { - dimension_sizes.push_back(shape.dimensions(i)); - } - } - - // Construct minor_to_major of stripped shape. The order of the non-degenerate - // dimensions should be preserved from the original shape. First, create - // vector of the non-degenerate dimensions from the original minor_to_major - // array. - std::vector minor_to_major; - for (int64 i : shape.layout().minor_to_major()) { - if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(), - i) == degenerate_dimensions.end()) { - minor_to_major.push_back(i); - } - } - - // The dimensions in minor_to_major need to be renumbered to account for the - // degenerate dimensions which have removed. 
Decrement each dimension number - // once for each degenerate dimension which has a smaller number. - for (int i = 0; i < minor_to_major.size(); ++i) { - int adjustment = 0; - for (int64 dim : degenerate_dimensions) { - if (minor_to_major[i] > dim) { - adjustment++; - } - } - minor_to_major[i] -= adjustment; - } - - { - std::vector dims(minor_to_major.size()); - std::iota(dims.begin(), dims.end(), 0); - DCHECK(minor_to_major.size() == dims.size() && - std::is_permutation(minor_to_major.begin(), minor_to_major.end(), - dims.begin())); - } - Shape stripped_shape; - if (LayoutUtil::IsDenseArray(shape)) { - stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes, - minor_to_major); - } else if (LayoutUtil::IsSparseArray(shape)) { - stripped_shape = - MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes, - shape.layout().max_sparse_elements()); - } else { - stripped_shape = MakeShape(shape.element_type(), dimension_sizes); - } - - VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape); - VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape); - return stripped_shape; -} - namespace { // Helper for ForEachSubshape which visits the subshapes of the given shape in diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 3853ada6ba65dbb1ac0754bcf753b4553ec260e7..8ee3f490a0837ec363758f6c633d73aa57687db4 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -62,6 +62,8 @@ class ShapeIndex { public: ShapeIndex() = default; ShapeIndex(std::initializer_list init) : indices_(init) {} + template + ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {} bool empty() const { return indices_.empty(); } size_t size() const { return indices_.size(); } @@ -132,6 +134,7 @@ class ShapeIndexView { ++new_begin; return ShapeIndexView(new_begin, end_); } + ShapeIndex ToShapeIndex() const { return ShapeIndex(begin_, end_); } bool operator==(const 
ShapeIndexView& other) const; bool operator!=(const ShapeIndexView& other) const; @@ -172,8 +175,8 @@ class ShapeUtil { // Precondition: IsArray(shape) static int64 ElementsIn(const Shape& shape); - // Returns true if 'shape' has zero elements. - static bool HasZeroElements(const Shape& shape); + // Returns true if 'shape' is an array with zero elements. + static bool IsZeroElementArray(const Shape& shape); // Returns the number of bytes required for an allocation of shape. The // |pointer_size| parameter is used for calculating the size of tuple @@ -333,7 +336,7 @@ class ShapeUtil { // Appends a major dimension to the shape with the given bound. static void AppendMajorDimension(int bound, Shape* shape); - // Returns an empty tuple shape. Can be used to indicate side-effects. + // Returns an empty tuple shape. Can be used as a sentinel Shape value. static Shape MakeNil() { return MakeTupleShape({}); } // Checks whether the shape is initialized. @@ -443,7 +446,7 @@ class ShapeUtil { // Returns true if shape is an empty tuple. static bool IsEmptyTuple(const Shape& shape); - // Returns true if shape is an empty tuple, or is an array with no elements. + // Returns true if shape is the nil shape (an empty tuple). static bool IsNil(const Shape& shape); // Returns the number of elements in the given tuple shape. @@ -473,8 +476,11 @@ class ShapeUtil { static bool IndexIsValid(const Shape& shape, ShapeIndexView index); // GetSubshape and GetMutableSubshape return a particular nested Shape within - // the given Shape argument. + // the given Shape argument. The non-Try variants check fail if index is + // invalid. 
static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index); + static StatusOr TryGetSubshape(const Shape& shape, + ShapeIndexView index); static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index); // Returns whether the given index in the given shape is a leaf element of the @@ -510,26 +516,6 @@ class ShapeUtil { static Status ForEachMutableSubshapeWithStatus( Shape* shape, const MutatingStatusVisitorFunction& func); - // Removes all degenerate dimensions (size one) from the given shape. The - // stripped minor_to_major preserves the relative ordering of non-degenerate - // dimensions. The stripped shape has the property that the underlying - // representation (bits in memory) for the stripped shape is the same as the - // original shape modulo padding. Examples: - // - // input shape: F32 [1, 2, 1], minor_to_major = {0, 1, 2} - // stripped shape: F32 [2], minor_to_major = {0} - // - // input shape: F32 [6, 1, 5], minor_to_major = {2, 0, 1} - // stripped shape: F32 [6, 5], minor_to_major = {1, 0} - // - // input shape: F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1} - // stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1} - // - // input shape: F32 [1, 1], minor_to_major = {0, 1} - // stripped shape: F32 [], minor_to_major = {} - // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) - static Shape StripDegenerateDimensions(const Shape& shape); - // Permutes the dimensions by the given permutation, so // return_value.dimensions[permutation[i]] = argument.dimensions[i] static Shape PermuteDimensions(tensorflow::gtl::ArraySlice permutation, @@ -714,7 +700,7 @@ class ShapeUtil { tensorflow::gtl::ArraySlice incr, const FnType& visitor_function, bool parallel = false) { - if (ShapeUtil::HasZeroElements(shape)) { + if (ShapeUtil::IsZeroElementArray(shape)) { return Status::OK(); } CHECK_EQ(Rank(shape), base.size()); diff --git a/tensorflow/compiler/xla/shape_util_test.cc 
b/tensorflow/compiler/xla/shape_util_test.cc index ecdb6532f1d743c7dacc266eeba615e19748ee27..61aa198e524373f84b7e950d5835dd2457c88a62 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -172,6 +172,41 @@ TEST(ShapeUtilTest, CompatibleIdenticalShapes) { ASSERT_TRUE(ShapeUtil::Compatible(shape1, shape2)); } +TEST(ShapeUtilTest, TokenCompatibility) { + EXPECT_TRUE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTokenShape())); + EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShape(F32, {}))); + EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTokenShape())); + EXPECT_TRUE(ShapeUtil::Compatible( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}))); +} + +TEST(ShapeUtilTest, TokensEqualShapes) { + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTokenShape())); + EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShape(F32, {}))); + EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTokenShape())); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}))); + EXPECT_FALSE(ShapeUtil::Equal( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {1, 0})}))); +} + TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) { Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); auto layout_1 = shape_1.mutable_layout(); @@ -329,6 +364,16 @@ TEST(ShapeUtilTest, ByteSizeOfWithPadding) { EXPECT_EQ(15 * 21 * 4, 
ShapeUtil::ByteSizeOf(shape)); } +TEST(ShapeUtilTest, NilShape) { + EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3}))); + EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {0, 1}))); + EXPECT_FALSE(ShapeUtil::IsNil( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})}))); + EXPECT_FALSE(ShapeUtil::IsNil( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {0})}))); +} + TEST(ShapeUtilTest, NestedTuple) { EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({}))); EXPECT_FALSE(ShapeUtil::IsNestedTuple( @@ -359,25 +404,30 @@ TEST(ShapeUtilTest, ElementsIn) { EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17}))); } -TEST(ShapeUtilTest, HasZeroElements) { - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {}))); - EXPECT_EQ(true, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0}))); - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1}))); - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5}))); - EXPECT_EQ(true, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5}))); - EXPECT_EQ(true, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17}))); +TEST(ShapeUtilTest, IsZeroElementArray) { + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {}))); + EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1}))); + EXPECT_FALSE( + 
ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 1}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2, 1}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 5}))); + EXPECT_TRUE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 0, 5}))); + EXPECT_TRUE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0, 3, 0}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 3, 5}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {13, 17}))); + + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeTupleShape({}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {0, 3, 0})}))); } TEST(ShapeUtilTest, SameDimensions) { @@ -742,16 +792,6 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } -TEST(ShapeUtilTest, StripDegenerateDimensions) { - EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions( - ShapeUtil::MakeShape(F32, {3, 1, 2})), - ShapeUtil::MakeShape(F32, {3, 2}))); - EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::StripDegenerateDimensions( - ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)), - ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10))); -} - TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 7f6bbe6f879fd9596601f99f034a0391a71c52f8..e7e0a19db0516e4210f6bb78d6b5e6968bf78b2a 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1203,6 +1203,22 @@ xla_test( ], ) 
+xla_test( + name = "token_hlo_test", + srcs = ["token_hlo_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "call_test", srcs = ["call_test.cc"], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 36a706496918ac8c15780473019e2a8d098ffa22..c3a289ee09cc1ee7b9d705a38c26a3ac7a8a6aa2 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -2758,7 +2758,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), ::testing::ContainsRegex( - "Expected non-opaque argument for lhs of binary operation")); + "Expected array argument for lhs of binary operation")); } XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 34c86e007beea1cbac04641bdbdab62dc567f13e..3a0f51fc66d65c8684bd607b9e8103559cd4d8d4 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -671,7 +671,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { @@ -684,7 +684,7 @@ XLA_TEST_F(BroadcastSimpleTest, 
InvalidDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index a4c8a83eb15f7cc279b6c8f1bf1394c0afb9f7cf..352864502a184237fde600330836fe471a5444f2 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -417,7 +417,22 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), - HasSubstr("Expected non-opaque argument for operand of concatenation")); + HasSubstr("Expected array argument for operand of concatenation")); +} + +// Show that we can't concatenate with tokens. +XLA_TEST_F(ConcatTest, CannotConcatTokens) { + XlaBuilder builder(TestName()); + auto token_shape = ShapeUtil::MakeTokenShape(); + auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); + auto x = builder.Parameter(0, r1f32, "x"); + auto y = builder.Parameter(1, token_shape, "y"); + builder.ConcatInDim({x, y}, 0); + StatusOr computation_status = builder.Build(); + ASSERT_FALSE(computation_status.ok()); + EXPECT_THAT( + computation_status.status().ToString(), + HasSubstr("Expected array argument for operand of concatenation")); } XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 947959beb144e1509a77ad2f94b8493de46ba6f2..346bb3a3996ee5bf662b0f74dd0c2096efbf5295 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -47,9 +47,9 @@ class ConvolutionTest : public ClientLibraryTestBase { #if XLA_TEST_BACKEND_GPU // XLA:GPU sometimes uses FFT convolution which 
isn't as precise as spatial // convolution. So relax the absolute error threshold. - ErrorSpec error_spec_ = ErrorSpec(1e-2); + ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-4); #else - ErrorSpec error_spec_ = ErrorSpec(1e-4); + ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-4); #endif }; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 08ed826c80823efe0af8ce682945fe7e46d267ae..242cc5db11ff2bdf69209df7537216573d8afbf3 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -94,8 +94,7 @@ HloTestBase::HloTestBase(se::Platform* test_platform, /* static */ std::unique_ptr HloTestBase::CreateNewModule(const string& name) { - return MakeUnique(name, VersionedComputationHandle(), - GetModuleConfigForTest()); + return MakeUnique(name, GetModuleConfigForTest()); } /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index eb3a2ea76a667a2afa2562f01d28f34384b84a21..249da87f489324ed9d377cc46a15cef5a9e74192 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -66,6 +66,15 @@ namespace xla { // // For a more detailed example, see "../tests/sample_text_test.cc". class HloTestBase : public ::testing::Test { + public: + // Creates a new HLO module for a test. The module created will have + // TestName() for its name; it will also automatically populate its debug + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. + static std::unique_ptr CreateNewModule( + const string& name = TestName()); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. 
If the @@ -80,14 +89,6 @@ class HloTestBase : public ::testing::Test { ~HloTestBase() override {} - // Creates a new HLO module for a test. The module created will have - // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. If you want a fresh HloModule object and - // then add HloComputations to it, it's recommended to use this method in your - // tests. - static std::unique_ptr CreateNewModule( - const string& name = TestName()); - // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file. diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index c8a05c2e9e971d86feb6ff893fcd25c6767af99f..22c664d1426c598dbb695ff1b66ce009b0a19c00 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -41,14 +41,17 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - VerifyModule(); + VerifyModule(module_.get()); + } + for (int i = 0; i < modules_.size(); ++i) { + VerifyModule(modules_.at(i).get()); } HloTestBase::TearDown(); } -void HloVerifiedTestBase::VerifyModule() { - HloVerifier verifier; - xla::StatusOr mutated = verifier.Run(module_.get()); +void HloVerifiedTestBase::VerifyModule(HloModule* module) { + HloVerifier verifier(/*allow_mixed_precision=*/true); + xla::StatusOr mutated = verifier.Run(module); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); } else { @@ -59,15 +62,20 @@ void HloVerifiedTestBase::VerifyModule() { HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = CreateNewModule(); + module_ = HloTestBase::CreateNewModule(); } return *module_; } 
+HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { + modules_.emplace_back(HloTestBase::CreateNewModule()); + return modules_.back().get(); +} + void HloVerifiedTestBase::ParseAndVerifyModule( tensorflow::StringPiece hlo_text) { CHECK(!module_) << "Called ParseModule when test already has a module."; TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); - VerifyModule(); + VerifyModule(module_.get()); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index e5bb14a8839acbdef8fd2b79bb0f574c46ea3d40..5b59cc77f61b05092d3afb331e73932c9edc5840 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -52,11 +52,23 @@ class HloVerifiedTestBase : public HloTestBase { shape_verifier_ = std::move(shape_verifier); } + // Creates a new module for a test, and stores it in modules_ so it can be + // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent + // creation of unverified modules. + HloModule* CreateNewModule(const string& name = TestName()); + + // It is confusing to store modules created by module() and CreateNewModule() + // in different fields, but it allows us to migrate tests to + // HloVerifiedTestBase more easily, so it's a win because we can verify more + // modules. See b/80488902. private: - std::unique_ptr module_; // Lazily populated. Access via module(). + // Lazily populated. Access via module(). + std::unique_ptr module_; + // Populated by calls to CreateNewModule. 
+ std::vector> modules_; std::unique_ptr shape_verifier_; bool tear_down_called_ = false; - void VerifyModule(); + static void VerifyModule(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 2f46ee0be216d7dabf1c476d3cfb7d528f8ab6a4..082bc34136e004795ce300c66591758f47c665fe 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -124,8 +124,7 @@ class LLVMCompilerTest : public ::testing::Test { static std::unique_ptr CreateNewModule() { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return MakeUnique(TestName(), VersionedComputationHandle(), - config); + return MakeUnique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 7df45bebebdd3eb2e71f27d831a8e2ac9e3b5f7c..3975e9125703ee081d4e84fa8bd27fcbe483ac34 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -488,10 +488,9 @@ TEST_F(MapTest, MapOperantionWithBuildError) { StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); - EXPECT_THAT( - computation_status.status().ToString(), - ::testing::HasSubstr("error from: ErrorAdd: Binary op BINOP_ADD with " - "different element types: f32[] and u16[]")); + EXPECT_THAT(computation_status.status().ToString(), + ::testing::HasSubstr("error from: ErrorAdd: Binary op add with " + "different element types: f32[] and u16[]")); } // MapTest disables inline and algsimp. 
MapTestWithFullOpt runs all diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ef54e6f89251bbd6dba0705698c6627c554791e --- /dev/null +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class TokenHloTest : public HloTestBase {}; + +XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, TokenTree) { + std::unique_ptr module = CreateNewModule(); + auto builder = 
HloComputation::Builder(TestName()); + auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateGenerateToken({token0, token0, token1, token2})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 1 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateParameter( + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()}), + "param")); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenRoot) { + std::unique_ptr module = CreateNewModule(); + auto builder = 
HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Entry root is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction(HloInstruction::CreateGenerateToken({param})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr( + "Operands of token instructions must be TOKEN types")); +} + +XLA_TEST_F(TokenHloTest, TokenInWhileLoop) { + // Thread a token around a while loop. Token is created and consumed by a + // GenerateToken instruction in the while body. 
+ string module_string = R"( +HloModule TokenInWhileLoop + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %generate-token = token[] generate-token(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %TokenInWhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] generate-token() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + EXPECT_TRUE(RunAndCompare(module_string, error_spec_)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index ff5340ee3fac51288eef43962ac6427cab64bc54..e4a052c8f1c0009619c3a94606f6384d04006e4e 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -85,6 +85,7 @@ cc_library( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/compiler/xla/tests:test_utils", diff --git 
a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index be094b7890aab08c55686c4785e01ff2ffba7cc2..f7574e0b1cc95daee6d6743ba4e2e490ee87e7c6 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -24,6 +24,9 @@ limitations under the License. // passing --use_fake_data on the command line. If the real data is available // in the proto and --use_fake_data is false, the real data is used. // +// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a +// textual HLO string. +// // The output format is: // // file_path: computation_name :: type:literal_str @@ -43,6 +46,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -195,25 +199,45 @@ StatusOr ReplayComputation(const HloSnapshot& module, return std::move(*result_literal); } +StatusOr ParseInputFile(const string& filename, + const Options& opts) { + tensorflow::Env* env = tensorflow::Env::Default(); + HloSnapshot snapshot; + if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) { + return snapshot; + } + CHECK(opts.use_fake_data) + << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " + "and textual HLO don't carry real data."; + fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", + filename.c_str()); + + if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) { + return snapshot; + } + fprintf(stderr, "%s: is not HloProto. 
Trying HLO text.\n", filename.c_str()); + string contents; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents)); + StatusOr> module = ParseHloString(contents); + if (module.ok()) { + *snapshot.mutable_hlo()->mutable_hlo_module() = + module.ValueOrDie()->ToProto(); + return snapshot; + } + fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n", + filename.c_str()); + return InvalidArgument("Could not parse %s.", filename.c_str()); +} + int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); - tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; for (char* arg : args) { - HloSnapshot snapshot; - auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot); - if (!status.ok()) { - fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", arg); - status = tensorflow::ReadBinaryProto(env, arg, snapshot.mutable_hlo()); - if (!status.ok()) { - fprintf(stderr, "%s: is not HloSnapshot or HloProto: %s.\n", arg, - status.ToString().c_str()); - continue; - } - CHECK(opts.use_fake_data) - << "HloProto input must be handled with --use_fake_data"; + StatusOr maybe_snapshot = ParseInputFile(arg, opts); + if (!maybe_snapshot.ok()) { + continue; } - + HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie(); StatusOr result_status = ReplayComputation(snapshot, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 53ba120d21a9e16904d8c709617fc0eda6be63c4..6f07e4606bef015214f2c564515c8258a906205b 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -225,14 +225,6 @@ message ExecutionOptions { repeated DeviceHandle device_handles = 5; } -message SnapshotComputationRequest { - ComputationHandle computation = 1; -} - -message LoadComputationSnapshotResponse { - ComputationHandle computation = 1; 
-} - message GetDeviceHandlesRequest { int64 device_count = 1; } @@ -291,11 +283,6 @@ message ResetDeviceRequest { message ResetDeviceResponse { } -message ComputationStatsRequest { - ComputationHandle computation = 1; - DebugOptions debug_options = 2; -} - message ComputationGraphStatsRequest { HloModuleProto computation = 1; DebugOptions debug_options = 2; @@ -305,14 +292,6 @@ message ComputationStatsResponse { ComputationStats stats = 1; } -message ComputationRequest { - string name = 1; -} - -message ComputationResponse { - ComputationHandle computation = 1; -} - message CreateChannelHandleRequest { } @@ -327,24 +306,6 @@ message UnregisterRequest { message UnregisterResponse { } -message SetReturnValueRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message SetReturnValueResponse { -} - -message ExecuteRequest { - reserved 3, 4; - - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 5; -} - message ExecuteGraphRequest { HloModuleProto computation = 1; repeated GlobalDataHandle arguments = 2; @@ -353,10 +314,6 @@ message ExecuteGraphRequest { ExecutionOptions execution_options = 3; } -message ExecuteParallelRequest { - repeated ExecuteRequest requests = 1; -} - message ExecuteGraphParallelRequest { repeated ExecuteGraphRequest requests = 1; } @@ -370,21 +327,6 @@ message ExecuteParallelResponse { repeated ExecuteResponse responses = 1; } -message ExecuteAsyncRequest { - reserved 3, 4; - - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 6; -} - -message ExecuteAsyncResponse { - // A handle to the execution launched asynchronously. 
- ExecutionHandle execution = 1; -} - message WaitForExecutionRequest { ExecutionHandle execution = 1; } @@ -394,31 +336,13 @@ message WaitForExecutionResponse { ExecutionProfile profile = 2; } -message IsConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - int64 num_parameters = 3; -} - -message IsConstantResponse { - bool is_constant = 1; -} - -message ComputeConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - Layout output_layout = 3; - repeated LiteralProto parameters = 4; -} - message ComputeConstantGraphRequest { HloModuleProto computation = 1; Layout output_layout = 2; } message ComputeConstantResponse { - // A LiteralProto is returned directly for this request, instead of a - // ComputationDataHandle. + // A LiteralProto is returned directly for this request. LiteralProto literal = 1; } @@ -460,14 +384,6 @@ message LoadDataResponse { int64 nanoseconds = 5; } -message SpecializeRequest { - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; -} - -message SpecializeResponse { -} - message GetShapeRequest { GlobalDataHandle data = 1; } @@ -476,14 +392,6 @@ message GetShapeResponse { Shape shape = 1; } -message GetComputationShapeRequest { - ComputationHandle computation = 1; -} - -message GetComputationShapeResponse { - ProgramShape program_shape = 1; -} - message UnpackRequest { GlobalDataHandle data = 1; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 6bdfb0179cd6a5e4eaee20cd877bd976e0e173c3..0af73e8a93060f4569ddef9697b89a6fa2b8674b 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -276,12 +276,6 @@ message ExecutionProfile { int64 compute_and_transfer_time_ns = 5; } -// Handle given to a user that represents a computation that the user builds up -// before execution. 
-message ComputationHandle { - int64 handle = 1; -} - // Handle given to a user that represents an execution that the user launched // asynchronously on the device. message ExecutionHandle { @@ -295,13 +289,6 @@ message GlobalDataHandle { int64 handle = 1; } -// Handle given to a user that represents a data result in a computation. -// This is used to pass to subsequent computations that depends upon the data as -// an operand. -message ComputationDataHandle { - int64 handle = 1; -} - // Handle given to a user that represents a replicated virtual device. Each // replicated device represents N physical devices for execution where N is the // number of replicas. @@ -441,44 +428,6 @@ message GatherDimensionNumbers { int64 index_vector_dim = 4; } -// Operation requests that are all collected as a tagged union with a oneof -// field in OpRequest. - -message ConstantRequest { - LiteralProto literal = 2; -} - -message GetTupleElementRequest { - ComputationDataHandle operand = 2; - int64 index = 3; -} - -message SliceRequest { - ComputationDataHandle operand = 2; - repeated int64 start_indices = 3; - repeated int64 limit_indices = 4; - repeated int64 strides = 5; -} - -message DynamicSliceRequest { - // Operand from which to slice at dynamic 'start_indices'. - ComputationDataHandle operand = 2; - // Dynamically computed 'start_indices' for slice operation. - ComputationDataHandle start_indices = 3; - // Slice sizes for each dimension (note that indices calculations are computed - // modulo dimension sizes to avoid out-of-bound array accesses). - repeated int64 slice_sizes = 4; -} - -message DynamicUpdateSliceRequest { - // Operand on which slice 'update' is to be applied. - ComputationDataHandle operand = 2; - // The slice update to apply to 'operand'. - ComputationDataHandle update = 3; - // Dynamically computed start indices for the update slice operation. 
- ComputationDataHandle start_indices = 4; -} - message ConvolutionDimensionNumbers { // The number of the dimension that represents batch in the input. int64 input_batch_dimension = 7; @@ -516,13 +465,6 @@ message ConvolutionDimensionNumbers { // Next = 13 }; -message ConvolveRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; // This is the filter/kernel. - Window window = 4; // Describes the filter/kernel. - ConvolutionDimensionNumbers dimension_numbers = 5; -} - enum FftType { FFT = 0; // Forward FFT; complex in, complex out. IFFT = 1; // Inverse FFT; complex in, complex out. @@ -531,56 +473,6 @@ enum FftType { // fft_length real out } -message FftRequest { - FftType fft_type = 1; - repeated int64 fft_length = 2; // Multivalent for higher-order FFT. - ComputationDataHandle operand = 3; -} - -message InfeedRequest { - // The shape of the data returned by reading the device's infeed buffer. - Shape shape = 2; - - // Additional infeed configuration for the backend. - bytes config = 3; -} - -message OutfeedRequest { - // The shape of the data returned by reading the device's outfeed buffer. - Shape shape = 1; - - // Operand to the Outfeed. Supports tuple. - ComputationDataHandle operand = 2; - - // Backend-specific information for how to perform the outfeed. - bytes outfeed_config = 3; -} - -message CallRequest { - ComputationHandle to_apply = 2; - repeated ComputationDataHandle operands = 3; -} - -message CustomCallRequest { - string call_target_name = 2; - repeated ComputationDataHandle operands = 3; - Shape shape = 4; -} - -message HostComputeRequest { - // Operand to the HostCompute. Supports tuple. - repeated ComputationDataHandle operands = 1; - - // Name used to identify HostSend/Recv channels. - string channel_name = 2; - - // Cost estimate in nanoseconds. - int64 cost_estimate_ns = 3; - - // The shape of any data returned by host. 
- Shape shape = 4; -} - message DotDimensionNumbers { // The dimension numbers that represent the 'lhs' contracting dimensions. repeated int64 lhs_contracting_dimensions = 1; @@ -592,297 +484,6 @@ message DotDimensionNumbers { repeated int64 rhs_batch_dimensions = 4; }; -message DotRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; - DotDimensionNumbers dimension_numbers = 4; -} - -message MapRequest { - repeated ComputationDataHandle operands = 2; - ComputationHandle to_apply = 3; - repeated ComputationDataHandle static_operands = 4; - // The dimensions over which to map. - // Example mapping a Dot operation along the batch dimension 0: - // operand0.shape = [2, 2, 2], operand1.shape = [2,2,3] - // Map({operand0, operand1}, Dot, {0}) - repeated int64 dimensions = 5; -} - -message ReduceRequest { - // Operand to the reduction. - ComputationDataHandle operand = 2; - - // Initial value for the reduction. This must be consistent with the result - // shape of to_apply. - ComputationDataHandle init_value = 3; - - // The dimensions to reduce over. - repeated int64 dimensions = 4; - - // The computation to apply in the reduction. 
- ComputationHandle to_apply = 5; -} - -message ReduceWindowRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle init_value = 3; - Window window = 4; - ComputationHandle to_apply = 5; -} - -message BatchNormTrainingRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - float epsilon = 4; - int64 feature_index = 5; -} - -message BatchNormInferenceRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - ComputationDataHandle mean = 4; - ComputationDataHandle variance = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message BatchNormGradRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle mean = 3; - ComputationDataHandle variance = 4; - ComputationDataHandle grad_output = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message CrossReplicaSumRequest { - ComputationDataHandle operand = 2; -} - -message SelectAndScatterRequest { - // Operand array on which the windows slide. - ComputationDataHandle operand = 2; - - // Source array for the data to scatter. - ComputationDataHandle source = 3; - - // Initial scalar value for each element in the output. - ComputationDataHandle init_value = 4; - - // Window configuration. - Window window = 5; - - // Binary function used to select an element from each window. - ComputationHandle select = 6; - - // Binary function used to combine each scattered value from source with the - // current output value at the selected location. 
- ComputationHandle scatter = 7; -} - -message ReverseRequest { - ComputationDataHandle operand = 2; - repeated int64 dimensions = 3; -} - -message BroadcastRequest { - ComputationDataHandle operand = 2; - repeated int64 broadcast_sizes = 3; -} - -message PadRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle padding_value = 3; - PaddingConfig padding_config = 4; -} - -message ReshapeRequest { - ComputationDataHandle operand = 2; - - // The dimension order for collapse (from fastest-changing to slowest). - repeated int64 dimensions = 3; - - // The new dimension sizes (from dimension 0 to n-1). - repeated int64 new_sizes = 4; -} - -message TransposeRequest { - ComputationDataHandle operand = 2; - - // The permutation of the operand's dimensions (in the range 0 to n-1). - repeated int64 dimensions = 3; -} - -message ParameterRequest { - Shape shape = 2; - int64 parameter = 3; - string name = 4; -} - -message GetLocalShapeRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message GetLocalShapeResponse { - Shape shape = 1; -} - -message TraceRequest { - string tag = 2; - ComputationDataHandle operand = 3; -} - -message ConvertRequest { - ComputationDataHandle operand = 2; - PrimitiveType new_element_type = 3; -} - -message ConcatenateRequest { - repeated ComputationDataHandle operands = 2; - // The dimension in which we concatenate; e.g. if you had dimension arrays of - // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1]. - // Attempting to concatenate those in dimension 1 would produce an error, as - // 4 != 5 (and there is no ragged array support). 
- int64 dimension = 3; -} - -message ConditionalRequest { - ComputationDataHandle predicate = 2; - ComputationDataHandle true_operand = 3; - ComputationHandle true_computation = 4; - ComputationDataHandle false_operand = 5; - ComputationHandle false_computation = 6; -} - -message WhileRequest { - ComputationHandle condition = 2; - ComputationHandle body = 3; - ComputationDataHandle init = 4; -} - -enum UnaryOperation { - UNOP_INVALID = 0; - - // Elementwise, logical negation on booleans and bitwise negation on ints. - UNOP_NOT = 1; - - // Elementwise, computes e^x. - UNOP_EXP = 2; - - // Elementwise, computes -x. - UNOP_NEGATE = 3; - - // Puts the elements in the operand into sorted order. - UNOP_SORT = 4; - - // Elementwise, computes tanh(x). - UNOP_TANH = 5; - - // Elementwise, computes the natural logarithm of x. - UNOP_LOG = 6; - - // Elementwise, computes the floor of x. - UNOP_FLOOR = 7; - - // Elementwise, computes the ceil of x. - UNOP_CEIL = 8; - - // Elementwise, computes the abs of x. - UNOP_ABS = 9; - - // Elementwise, computes the sign of x. - UNOP_SIGN = 10; - - // Elementwise, tests if values are finite (not NaN or inf) - UNOP_IS_FINITE = 11; - - // Elementwise, computes the cosine of x. - UNOP_COS = 12; - - // Elementwise, computes the sine of x. - UNOP_SIN = 13; - - // Elementwise, rounds x to nearest integral value, rounding half-way cases - // away from zero. - UNOP_ROUND_NEAREST_AFZ = 14; - - // Elementwise, extract real component of complex x. - UNOP_REAL = 15; - - // Elementwise, extract real component of complex x. - UNOP_IMAG = 16; - - // Elementwise, computes clz(x). - UNOP_CLZ = 17; - - // Elementwise, computes exp(x)-1. - UNOP_EXPM1 = 18; - - // Elementwise, computes log(x+1). - UNOP_LOG1P = 19; -} - -message UnaryOpRequest { - UnaryOperation unop = 2; - ComputationDataHandle operand = 3; -} - -enum BinaryOperation { - BINOP_INVALID = 0; - - // Arithmetic operations. 
- BINOP_ADD = 1; - BINOP_DIV = 2; - BINOP_MUL = 3; - BINOP_SUB = 4; - - // Comparison operators. - BINOP_EQ = 5; - BINOP_GE = 6; - BINOP_GT = 7; - BINOP_LE = 8; - BINOP_LT = 9; - BINOP_NE = 10; - - // Element-wise maximum. - BINOP_MAX = 14; - - // Element-wise minimum. - BINOP_MIN = 15; - - // Raises the left-hand-side to the right-hand-side power. - BINOP_POW = 16; - - // Remainder operation. - BINOP_REM = 17; - - // Element-wise, logical operators on booleans and bitwise operators on ints. - BINOP_AND = 18; - BINOP_OR = 19; - - BINOP_SHIFT_LEFT = 20; - BINOP_SHIFT_RIGHT_ARITHMETIC = 21; - BINOP_SHIFT_RIGHT_LOGICAL = 22; - - // Complex from real, imag. - BINOP_COMPLEX = 23; - - // Computes the 4-quadrant arctangent of the y, x input arguments. - BINOP_ATAN2 = 24; -} - -message BinaryOpRequest { - BinaryOperation binop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - repeated int64 broadcast_dimensions = 5; -} - enum RandomDistribution { RNG_INVALID = 0; @@ -897,67 +498,6 @@ enum RandomDistribution { // Next: 4 } -message RngRequest { - RandomDistribution distribution = 2; - repeated ComputationDataHandle parameter = 3; - Shape shape = 4; -} - -enum TernaryOperation { - TRIOP_INVALID = 0; - - // Given a predicate and two operands, selects operand0 if the predicate is - // true and operand1 if the predicate is false. - TRIOP_SELECT = 1; - - // Given a min, max and an operand returns the operand if between min and max, - // else returns min if operand is less than min or max if operand is greater - // than max. - TRIOP_CLAMP = 3; -} - -message TernaryOpRequest { - TernaryOperation triop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - ComputationDataHandle ehs = 5; -} - -enum VariadicOperation { - VAROP_INVALID = 0; - - // Creates a tuple from its operands. 
- VAROP_TUPLE = 1; -} - -message VariadicOpRequest { - VariadicOperation varop = 2; - repeated ComputationDataHandle operands = 3; -} - -message ReducePrecisionRequest { - ComputationDataHandle operand = 1; - int32 exponent_bits = 2; - int32 mantissa_bits = 3; -} - -message SendRequest { - ComputationDataHandle operand = 1; - ChannelHandle channel_handle = 2; -} - -message RecvRequest { - Shape shape = 1; - ChannelHandle channel_handle = 2; -} - -message GatherRequest { - ComputationDataHandle input = 1; - ComputationDataHandle gather_indices = 2; - GatherDimensionNumbers dimension_numbers = 3; - repeated int64 window_bounds = 4; -} - message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -988,59 +528,3 @@ message OpSharding { // to. repeated OpSharding tuple_shardings = 5; } - -message OpRequest { - ComputationHandle computation = 1; - OpMetadata metadata = 33; - OpSharding sharding = 40; - - oneof op { - BinaryOpRequest binary_op_request = 2; - BroadcastRequest broadcast_request = 3; - CallRequest call_request = 4; - ConcatenateRequest concatenate_request = 5; - ConstantRequest constant_request = 6; - ConvertRequest convert_request = 7; - ConvolveRequest convolve_request = 8; - CrossReplicaSumRequest cross_replica_sum_request = 9; - CustomCallRequest custom_call_request = 10; - DotRequest dot_request = 43; - DynamicSliceRequest dynamic_slice_request = 11; - DynamicUpdateSliceRequest dynamic_update_slice_request = 12; - GetTupleElementRequest get_tuple_element_request = 13; - InfeedRequest infeed_request = 14; - MapRequest map_request = 15; - PadRequest pad_request = 16; - ParameterRequest parameter_request = 17; - ReducePrecisionRequest reduce_precision_request = 36; - ReduceRequest reduce_request = 18; - ReduceWindowRequest reduce_window_request = 19; - ReshapeRequest reshape_request = 20; - ReverseRequest reverse_request = 21; - RngRequest rng_request = 22; - SelectAndScatterRequest select_and_scatter_request = 
23; - SliceRequest slice_request = 24; - TernaryOpRequest ternary_op_request = 25; - TraceRequest trace_request = 26; - TransposeRequest transpose_request = 34; - UnaryOpRequest unary_op_request = 27; - VariadicOpRequest variadic_op_request = 28; - WhileRequest while_request = 29; - SendRequest send_request = 30; - RecvRequest recv_request = 31; - OutfeedRequest outfeed_request = 32; - BatchNormTrainingRequest batch_norm_training_request = 35; - BatchNormGradRequest batch_norm_grad_request = 37; - BatchNormInferenceRequest batch_norm_inference_request = 38; - FftRequest fft_request = 41; - ConvertRequest bitcast_convert_request = 42; - ConditionalRequest conditional_request = 44; - HostComputeRequest host_compute_request = 45; - GatherRequest gather_request = 46; - // Next: 47 - } -} - -message OpResponse { - ComputationDataHandle output = 1; -} diff --git a/tensorflow/contrib/autograph/LIMITATIONS.md b/tensorflow/contrib/autograph/LIMITATIONS.md new file mode 100644 index 0000000000000000000000000000000000000000..d8b1cb7616ac348981bf2b69d6e2fd8d8a6e6b78 --- /dev/null +++ b/tensorflow/contrib/autograph/LIMITATIONS.md @@ -0,0 +1,50 @@ +# Capabilities and Limitations + +TF AutoGraph converts Eager Python code into TensorFlow graph-mode code. For example, users write code with `if` and `while` and AutoGraph automatically converts it into the equivalent `tf.cond`, and `tf.while_loop`. + +Python is a large language, so hoping to convert arbitrary Python code directly to TF graphs is overly ambitious. However, the Python code written to metaprogram TF graphs is in practice a restricted subset. We aim to support as much of this subset as possible. The table below lays out what we currently handle, what we hope to support, and what we have no plans to support. 
+ +# Python Language Support Status + +Note: as more complex features in TensorFlow are made more accessible using AutoGraph, we expect to come across use cases that haven't been tried before, some of which might reveal rare bugs. If we do find any such bugs, we may add additional restrictions for the affected configurations, until those bugs are resolved. + + Construct | Supported now? | Plan to support? | Notes + :--------- | :--------------: | :----------------: | :----- +If statement | Yes | | Converts to `tf.cond`. If variables are created in one branch that don’t exist in another, which is inexpressible in TF, we throw a clear error. +For statement | Yes | | We will specialize `for` loops with unknown and known lengths, as well as for loops over TF datasets. Converts to `tf.while_loop`, with an additional `maximum_iterations` hint, if that is known. Creating variables inside the loop that are used later outside the loop is not supported, as the loop may have no iterations. +While statement | Yes | | Converts to `tf.while_loop`. Creating variables inside the loop is not supported, as the loop may have no iterations. +Continue and break | Yes | | Converts to boolean flags and extra predicates in loop tests. +Composition of control flow | Yes | | Arbitrary composition of `if`, `while`, `for`, `break`, and `continue`, along with other supported language elements, is supported and tested. +Iterators | Some | Yes | Not all iterators supported, but we plan to support everything that can be desugared, such as `enumerate` and `zip`. +Multiple return values | Yes | | We desugar them into variables, boolean flags and conditionals so that the function has a single return value at the end, and provide a clear error if we are unable to do so. +Print expression | Yes | | Wrapped in `PyFunc`, and given proper control dependencies. Optional support for using tf.Log when py_func is undesirable exists. 
+Static function calls | Yes | | Non-recursive function calls +Nested call trees | Yes | | For example, `f` calls `g` which calls `h`, all of which need conversion. +Recursive function calls | No | Maybe | Based on available support in TF. Currently `function.Defun` is the best candidate, but it is not reentrant. +Python built-ins | Some | Yes | `print`, `len`, `range`, `xrange`, `int`, `float` are supported, and we plan to support or clearly error on all [Python built-ins](https://docs.python.org/3/library/functions.html). +List operations | Yes | | We convert list creation, append, pop and indexing to their TF TensorArray equivalents. However, we do need some extra type hints to fully convert correctly. We hope to remove this limitation. +Function variables | Yes | | e.g. `f_new = f_orig; f_new()` +Lambda functions | No | Yes | Planned feature. +Classes | Yes | | Classes can be converted all at once, or method-by-method. Some limitations exist around static and class methods. +Subclasses | Yes | | Subclassing library objects like tf.keras.Model is also supported. +Dynamic types | Some | | `o = C1() if foo else C2(); o.bar()`. Some scenarios where types are data-dependent may not be supported. We will raise a meaningful error in that case. +Dynamic code / exec | No | | +Reflection | No | | +Try / Except | No | No | No current sane TF equivalent. +Global variables | Restricted | | In general, we only support read-only access to arguments or variables defined outside the converted code. A few exceptions include TensorFlow library code. +Functions with side effects | Some | | Side effects are allowed, under certain circumstances. +Collections | Some | Yes | We currently support lists. There are currently no TF equivalents of dictionaries or tuples. +List Comprehensions | Yes | | We desugar `ListComp` into the appropriate combination of `For` and `If` statements. Other comprehensions are currently very low priority. 
+Custom context managers | No | Yes | Currently low priority. Left unconverted currently. +Generators | No | Maybe | Could be achievable using queues; very low priority. +Assertions | Yes | | As `tf.Assert` +Deletion | Yes | Maybe | Currently unconverted. If new semantics are required for `del`, we are able to add it in. +Inline imports | No | Yes | For example, `import numpy as np; np.eye(3)`. Currently low priority. +Async | No | No | + +## Extra capabilities + + - We liberally add name scopes to generated functions + - Operations get decent default names everywhere (planned) + - Statements that have no output values are given correct control dependencies. For example, `for i in range(n): print(i)` will have control dependencies to ensure the `print` statements are executed serially. + diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index 674859bed4ec157d5d5b33b6fc015c930e54b392..829a57d8e61ee4a41076f7397488cd85bdca1376 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -120,3 +120,15 @@ You can use the functional API to inspect the generated code as well: print(ag.to_code(f)) # Output: ``` + +## Filing bugs and feature requests + +### Reporting a bug + + - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message. + - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message. + - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you.
+ +### Requesting a feature + +If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there. diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md index 866e5f583a34570dfddc733f57561ed1d2b7c5bf..7e6b0cc27dd1cf8c0f459a0a34f98092728342a2 100644 --- a/tensorflow/contrib/autograph/STYLE_GUIDE.md +++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md @@ -20,7 +20,17 @@ Naming conventions: Below are AutoGraph-specific conventions. In the event of conflict, it supercedes all previous conventions. -1. __Citations in Docstrings.__ Write a `#### References` subsection at the +1. __Types in docstrings.__ Use [PEP 484](https://www.python.org/dev/peps/pep-0484/) + notation to describe the type for args, return values and attributes. + + Example: + + ``` + Args: + foo: Dict[str, List[int]], a dictionary of sorts + ``` + +2. __Citations in Docstrings.__ Write a `#### References` subsection at the bottom of any docstring with citations. Use ICLR’s bibliography style to write references; for example, order entries by the first author's last name. Add a link to the paper if the publication is open source (ideally, @@ -60,12 +70,12 @@ it supercedes all previous conventions. https://arxiv.org/abs/1803.04386 ``` -2. Avoid LaTeX in docstrings. +3. Avoid LaTeX in docstrings. * It is not rendered in many (if not most) editors and can be hard to read for both LaTeX experts and non-experts. -3. Write docstring and comment math using ASCII friendly notation; python using +4. Write docstring and comment math using ASCII friendly notation; python using operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`, `sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx: x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`.
diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index 79d73af98097aea418f2116aee40b2572b418ef7..dbdbad8f4c91c725294baa36acebbaf5b5e8cf5c 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -30,6 +30,8 @@ from tensorflow.contrib.autograph.impl.api import do_not_convert from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import to_code from tensorflow.contrib.autograph.impl.api import to_graph +from tensorflow.contrib.autograph.impl.directives import set_element_type +from tensorflow.contrib.autograph.impl.directives import set_loop_options from tensorflow.contrib.autograph.impl.special_functions import stack from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented @@ -42,8 +44,11 @@ _allowed_symbols = [ 'do_not_convert', 'to_code', 'to_graph', - # Special functions and overloaded operators + # Overloaded operators 'operators', + # Special functions and directives + 'set_element_type', + 'set_loop_options', 'stack', # Exceptions 'AutographParseError', diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index 8f9bffa55e44e4942bb3845945b3d440c7957cc9..284ad84be566199adaaa1ab641d37528ae4dfd2d 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -31,6 +31,7 @@ py_library( "name_scopes.py", "side_effect_guards.py", "single_return.py", + "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -208,3 +209,14 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) diff --git 
a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index d7ddbe8a04f64848d6ec21155d8d85f60e19d276..1e718f02d10ea1a520066c74f520144feee242b9 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -46,7 +46,7 @@ class SymbolNamer(object): class ControlFlowTransformer(transformer.Base): - """Transforms control flow structures like loops an conditionals.""" + """Transforms control flow structures like loops and conditionals.""" def _create_cond_branch(self, body_name, aliased_orig_names, aliased_new_names, body, returns): diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py index b49521b2c328f418828a5e92890aa1b169384b70..c15dfff9e8ebd8b96fd4aff82459a6fd7d0ac8ab 100644 --- a/tensorflow/contrib/autograph/converters/lists.py +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -33,82 +33,193 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.python.framework import dtypes +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +# Tags for local state. +POP_USES = 'pop_uses' class ListTransformer(transformer.Base): """Converts lists and related operations to their TF counterpart.""" - def _empty_list(self, node): - if not anno.hasanno(node, 'element_type'): - raise NotImplementedError( - 'type inference for empty lists is not yet supported; ' - 'use set_element_type(, ) to continue') - dtype = anno.getanno(node, 'element_type') - if not isinstance(dtype, dtypes.DType): - # TODO(mdan): Allow non-TF dtypes? 
- # That would be consistent with the dynamic dispatch pattern, but - # we must make sure that doesn't become confusing. - raise NotImplementedError('element type "%s" not yet supported' % dtype) - - dtype_name = dtype.name - # TODO(mdan): Does it ever make sense not to use tensor lists? + def visit_List(self, node): + node = self.generic_visit(node) template = """ - tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True) + ag__.new_list(elements) """ - return templates.replace_as_expression(template, dtype_name=dtype_name) + return templates.replace_as_expression(template, elements=node) - def _pre_populated_list(self, node): - raise NotImplementedError('pre-populated lists') + def _replace_append_call(self, node): + assert len(node.args) == 1 + assert isinstance(node.func, gast.Attribute) + template = """ + target = ag__.list_append(target, element) + """ + return templates.replace( + template, + target=node.func.value, + element=node.args[0]) + + def _replace_pop_call(self, node): + # Expressions that use pop() are converted to a statement + expression. + # + # For example: + # + # print(target.pop()) + # + # ... is converted to: + # + # target, target_pop = ag__.list_pop(target) + # print(target_pop) + # + # Here, we just generate the variable name and swap it in, + # and _generate_pop_operation will handle the rest. + # + # Multiple uses of pop() are allowed: + # + # print(target.pop(), target.pop()) + # print(target.pop().pop()) + # + assert isinstance(node.func, gast.Attribute) + scope = anno.getanno(node, NodeAnno.ARGS_SCOPE) + target_node = node.func.value + + # Attempt to use a related name if can get one. Otherwise use something + # generic.
+ if anno.hasanno(target_node, anno.Basic.QN): + target_name = anno.getanno(target_node, anno.Basic.QN).ssf() + else: + target_name = 'list' + pop_var_name = self.context.namer.new_symbol(target_name, scope.referenced) + + pop_uses = self.get_local(POP_USES, []) + pop_uses.append((node, pop_var_name)) + self.set_local(POP_USES, pop_uses) + + return templates.replace_as_expression('var_name', var_name=pop_var_name) + + def _replace_stack_call(self, node): + assert len(node.args) == 1 + dtype = anno.getanno( + node.args[0], + 'element_type', + default=templates.replace_as_expression('None')) + template = """ + ag__.list_stack( + target, + opts=ag__.ListStackOpts( + element_dtype=dtype, + original_call=orig_call)) + """ + return templates.replace_as_expression( + template, + dtype=dtype, + target=node.args[0], + orig_call=node.func) - def visit_Expr(self, node): + def visit_Call(self, node): node = self.generic_visit(node) - if isinstance(node.value, gast.Call): - call_node = node.value - - if not anno.hasanno(call_node.func, anno.Basic.QN): - return node - qn = anno.getanno(call_node.func, anno.Basic.QN) - - if qn.qn[-1] == 'append' and (len(call_node.args) == 1): - template = """ - target = ag__.utils.dynamic_list_append(target, element) - """ - node = templates.replace( - template, - target=qn.parent.ast(), - element=call_node.args[0]) + + # TODO(mdan): This is insufficient if target is a function argument. + # In the case of function arguments, we need to add the list to the + # function's return value, because it is being modified. + # TODO(mdan): Checking just the name is brittle, can it be improved? 
+ if isinstance(node.func, gast.Attribute): + func_name = node.func.attr + if func_name == 'append' and (len(node.args) == 1): + node = self._replace_append_call(node) + elif func_name == 'pop' and (len(node.args) <= 1): + node = self._replace_pop_call(node) + elif func_name == 'stack' and (len(node.args) == 1): + node = self._replace_stack_call(node) + return node - def _replace_list_constructors(self, targets, values): - for target in targets: - if (isinstance(target, (gast.Tuple, gast.List)) and - isinstance(values, (gast.Tuple, gast.List))): - n_targets = len(target.elts) - for i in range(n_targets): - target_el, value_el = target.elts[i], values.elts[i] - values.elts[i] = self._replace_list_constructors( - (target_el,), value_el) - return values - if isinstance(values, gast.List): - if values.elts: - return self._pre_populated_list(values) - else: - return self._empty_list(values) - return values - - def visit_Assign(self, node): - node = self.generic_visit(node) + def _generate_pop_operation(self, original_call_node, pop_var_name): + assert isinstance(original_call_node.func, gast.Attribute) + + if original_call_node.args: + pop_element = original_call_node.args[0] + else: + pop_element = parser.parse_expression('None') + # The call will be something like "target.pop()", and the dtype is hooked to + # target, hence the func.value. 
+ dtype = anno.getanno( + original_call_node.func.value, + 'element_type', + default=templates.replace_as_expression('None')) + shape = anno.getanno( + original_call_node.func.value, + 'element_shape', + default=templates.replace_as_expression('None')) + + template = """ + target, pop_var_name = ag__.list_pop( + target, element, + opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape)) + """ + return templates.replace( + template, + target=original_call_node.func.value, + pop_var_name=pop_var_name, + element=pop_element, + dtype=dtype, + shape=shape) + + def _postprocess_statement(self, node): + """Inserts any separate pop() calls that node may use.""" + pop_uses = self.get_local(POP_USES, None) + if pop_uses: + replacements = [] + for original_call_node, pop_var_name in pop_uses: + replacements.extend( + self._generate_pop_operation(original_call_node, pop_var_name)) + replacements.append(node) + node = replacements + self.exit_local_scope() + return node, None + + # TODO(mdan): Should we have a generic visit_block instead? + # Right now it feels that a visit_block would add too much magic that's + # hard to follow. 
+ + def _visit_and_process_block(self, block): + return self.visit_block( + block, + before_visit=self.enter_local_scope, + after_visit=self._postprocess_statement) + + def visit_FunctionDef(self, node): + node.args = self.generic_visit(node.args) + node.decorator_list = self.visit_block(node.decorator_list) + node.body = self._visit_and_process_block(node.body) + return node + + def visit_For(self, node): + node.target = self.visit(node.target) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_While(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_If(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node - # Only convert lists when they are assigned to a variable, e.g.: - # l = [] - # TODO(mdan): A similar pattern exists in type_info.py - # We should add a generic "unpack_assignment" function to the base - # transformer, that has the same effect as applying some logic to the SSA - # form. 
- node.value = self._replace_list_constructors(node.targets, node.value) + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body = self._visit_and_process_block(node.body) return node diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py index 74c6dc64f197f75eb3e66c01fb078467e8e8ea89..9f18ab9f44dd8c3f341a02b950f75317c676eff8 100644 --- a/tensorflow/contrib/autograph/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -22,74 +22,126 @@ from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import lists from tensorflow.python.framework import dtypes -from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops from tensorflow.python.platform import test class ListTest(converter_test_base.TestCase): - def test_empty_annotated_list(self): + def test_empty_list(self): def test_fn(): - l = [] - utils.set_element_type(l, dtypes.int32) - l.append(1) - return l + return [] - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = self.parse_and_analyze(test_fn, {}) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: - # TODO(mdan): Attach these additional modules automatically. - result.utils = utils - result.dtypes = dtypes + with self.compiled(node) as result: + tl = result.test_fn() + # Empty tensor lists cannot be evaluated or stacked. 
+ self.assertTrue(isinstance(tl, ops.Tensor)) + self.assertEqual(tl.dtype, dtypes.variant) + + def test_initialized_list(self): + + def test_fn(): + return [1, 2, 3] + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: with self.test_session() as sess: - self.assertAllEqual([1], sess.run(result.test_fn().stack())) + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) - def test_empty_annotated_lists_unpacked(self): + def test_list_append(self): def test_fn(): - l, m = [], [] - utils.set_element_type(l, dtypes.int32) - utils.set_element_type(m, dtypes.int32) - l.append(1) - m.append(2) - return l, m + l = [1] + l.append(2) + l.append(3) + return l - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = self.parse_and_analyze(test_fn, {}) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: + with self.compiled(node) as result: + with self.test_session() as sess: + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) + + def test_list_pop(self): + + def test_fn(): + l = [1, 2, 3] + utils.set_element_type(l, dtypes.int32, ()) + s = l.pop() + return s, l + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: result.utils = utils result.dtypes = dtypes with self.test_session() as sess: - res_l, res_m = result.test_fn() - self.assertEqual([1], sess.run(res_l.stack())) - self.assertEqual([2], sess.run(res_m.stack())) + ts, tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2]) + self.assertAllEqual(sess.run(ts), 3) + + def test_double_list_pop(self): - def 
test_empty_annotated_lists_list_unpacked(self): + def test_fn(l): + s = l.pop().pop() + return s + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: + test_input = [1, 2, [1, 2, 3]] + # TODO(mdan): Pass a list of lists of tensor when we fully support that. + # For now, we just pass a regular Python list of lists just to verify that + # the two pop calls are sequenced properly. + self.assertAllEqual(result.test_fn(test_input), 3) + + def test_list_stack(self): + + tf = None # Will be replaced with a mock. def test_fn(): - [l, m] = [], [] + l = [1, 2, 3] utils.set_element_type(l, dtypes.int32) - utils.set_element_type(m, dtypes.int32) - l.append(1) - m.append(2) - return l, m - - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + return tf.stack(l) + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: + with self.compiled(node, array_ops.stack, dtypes.int32) as result: result.utils = utils result.dtypes = dtypes with self.test_session() as sess: - res_l, res_m = result.test_fn() - self.assertEqual([1], sess.run(res_l.stack())) - self.assertEqual([2], sess.run(res_m.stack())) + self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3]) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py new file mode 100644 index 0000000000000000000000000000000000000000..85aeda9c4164eb70329bd50f789eea5441c8fc87 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -0,0 +1,83 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for slice operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer + + +class SliceTransformer(transformer.Base): + """Converts slicing operations to their TF counterpart. + + Currently, relying on the default slice operator that Tensor uses is + insufficient, because TensorArray and tensor lists use dedicated index read + and write functions. + """ + + def _process_single_assignment(self, target, value): + if not isinstance(target, gast.Subscript): + return None + + template = """ + target = ag__.set_item(target, key, item) + """ + return templates.replace( + template, target=target.value, key=target.slice, item=value) + + def visit_Assign(self, node): + node = self.generic_visit(node) + # TODO(mdan): Support unpackings and multiple assignments. + if len(node.targets) != 1: + raise NotImplementedError('multiple assignment') + replacement = self._process_single_assignment(node.targets[0], node.value) + if replacement is not None: + return replacement + return node + + def visit_Subscript(self, node): + node = self.generic_visit(node) + if not isinstance(node.slice, gast.Index): + # TODO(mdan): It might make more sense to wave them through. 
+ raise NotImplementedError('non-index slice') + + if not isinstance(node.ctx, gast.Load): + # Index writes are handled at a higher level, one at which the rvalue is + # also available. + return node + + dtype = anno.getanno( + node.value, + 'element_type', + default=templates.replace_as_expression('None')) + + template = """ + ag__.get_item( + target, + key, + opts=ag__.GetItemOpts(element_dtype=dtype)) + """ + return templates.replace_as_expression( + template, target=node.value, key=node.slice, dtype=dtype) + + +def transform(node, context): + return SliceTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2d7e1ea1a6c46fcc3a2c6972a24507646ef858 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices_test.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import slices +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SliceTest(converter_test_base.TestCase): + + def test_index_access(self): + + def test_fn(l): + utils.set_element_type(l, dtypes.int32) + return l[1] + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = slices.transform(node, self.ctx) + + with self.compiled(node, dtypes.int32) as result: + result.utils = utils + result.dtypes = dtypes + with self.test_session() as sess: + tl = list_ops.tensor_list_from_tensor( + [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) + y = result.test_fn(tl) + self.assertEqual(2, sess.run(y)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD index 91ae0b9b82c6f649c3c80b91ef894b2221cdc962..02f16ae1875d6bd1fb87d19f8bfc5cae900391dd 100644 --- a/tensorflow/contrib/autograph/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -20,6 +20,7 @@ py_library( "api.py", "config.py", "conversion.py", + "directives.py", "naming.py", "special_functions.py", ], diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 55a30dc127957b2a9caa053db843380c94bacfbf..7802bbbe27ec5fed891440af2f589801918b3bdd 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ 
b/tensorflow/contrib/autograph/impl/conversion.py @@ -38,6 +38,7 @@ from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.contrib.autograph.converters import name_scopes from tensorflow.contrib.autograph.converters import side_effect_guards from tensorflow.contrib.autograph.converters import single_return +from tensorflow.contrib.autograph.converters import slices from tensorflow.contrib.autograph.impl import config from tensorflow.contrib.autograph.impl import naming from tensorflow.contrib.autograph.pyct import ast_util @@ -371,6 +372,8 @@ def node_to_graph(node, ctx, nocompile_decorators): # TODO(mdan): Clean this up. # Some intermediate analyses are not required, and some comments got orphaned. + # TODO(mdan): We may assume all converters require analysis to be re-done. + # Past this point, line numbers are no longer accurate so we ignore the # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? @@ -393,6 +396,8 @@ def node_to_graph(node, ctx, nocompile_decorators): node = _static_analysis_pass(node, ctx) node = lists.transform(node, ctx) + node = _static_analysis_pass(node, ctx) + node = slices.transform(node, ctx) node = builtin_functions.transform(node, ctx) node = _static_analysis_pass(node, ctx) diff --git a/tensorflow/contrib/autograph/impl/directives.py b/tensorflow/contrib/autograph/impl/directives.py new file mode 100644 index 0000000000000000000000000000000000000000..aabe5d99394a0cb921196d1c6a6b2a9496ea7545 --- /dev/null +++ b/tensorflow/contrib/autograph/impl/directives.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Directives are special no-op functions that serve as compilation markers. + +They provide static information like type hints, compilation and TensorFlow +overrides. + +These serve as annotations in the compiled code, allowing the user some control +over the compilation process. They have no functional role at runtime. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +UNSPECIFIED = object() + + +def set_element_type(entity, dtype, shape=UNSPECIFIED): + """Indicates that the entity is expected to hold items of specified type/shape. + + The staged TensorFlow ops will reflect and assert this data type. Ignored + otherwise. + + Args: + entity: The entity to annotate. + dtype: TensorFlow dtype value to assert for entity. + shape: Optional shape to assert for entity. + """ + del entity + del dtype + del shape + + +def set_loop_options( + parallel_iterations=UNSPECIFIED, + back_prop=UNSPECIFIED, + swap_memory=UNSPECIFIED, + maximum_iterations=UNSPECIFIED): + """Specifies additional arguments to be passed to the enclosing while_loop. + + The parameters apply to and only to the immediately enclosing loop. It only + has effect if the loop is staged as a TF while_loop; otherwise the parameters + have no effect. + + Args: + parallel_iterations: See tf.while_loop. + back_prop: See tf.while_loop. + swap_memory: See tf.while_loop. + maximum_iterations: See tf.while_loop.
+ """ + del parallel_iterations + del back_prop + del swap_memory + del maximum_iterations diff --git a/tensorflow/contrib/autograph/lang/BUILD b/tensorflow/contrib/autograph/lang/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..77a2184e229003a3403cbe3bf116ad2570274a1b --- /dev/null +++ b/tensorflow/contrib/autograph/lang/BUILD @@ -0,0 +1,40 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "lang", + srcs = [ + "directives.py", + "special_functions.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/autograph/operators", + ], +) + +py_test( + name = "special_functions_test", + srcs = ["special_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":lang", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/lang/directives.py b/tensorflow/contrib/autograph/lang/directives.py new file mode 100644 index 0000000000000000000000000000000000000000..aabe5d99394a0cb921196d1c6a6b2a9496ea7545 --- /dev/null +++ b/tensorflow/contrib/autograph/lang/directives.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Directives are special no-op functions that serve as compilation markers. + +They provide static information like type hints, compilation and TensorFlow +overrides. + +These serve as annotations in the compiled code, allowing the user some control +over the compilation process. They have no functional role at runtime. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +UNSPECIFIED = object() + + +def set_element_type(entity, dtype, shape=UNSPECIFIED): + """Indicates that the entity is expected to hold items of specified type/shape. + + The staged TensorFlow ops will reflect and assert this data type. Ignored + otherwise. + + Args: + entity: The entity to annotate. + dtype: TensorFlow dtype value to assert for entity. + shape: Optional shape to assert for entity. + """ + del entity + del dtype + del shape + + +def set_loop_options( + parallel_iterations=UNSPECIFIED, + back_prop=UNSPECIFIED, + swap_memory=UNSPECIFIED, + maximum_iterations=UNSPECIFIED): + """Specifies additional arguments to be passed to the enclosing while_loop. + + The parameters apply to and only to the immediately enclosing loop. It only + has effect if the loop is staged as a TF while_loop; otherwise the parameters + have no effect. + + Args: + parallel_iterations: See tf.while_loop. + back_prop: See tf.while_loop. + swap_memory: See tf.while_loop. + maximum_iterations: See tf.while_loop. 
+ """ + del parallel_iterations + del back_prop + del swap_memory + del maximum_iterations diff --git a/tensorflow/contrib/autograph/lang/special_functions.py b/tensorflow/contrib/autograph/lang/special_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..11135295a7966bc5d693676fcc71fe43791f2e99 --- /dev/null +++ b/tensorflow/contrib/autograph/lang/special_functions.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Special functions that only make sense for AutoGraph. + +These functions are meant to ensure feature parity between Python and AutoGraph, +so that the exact same code works in both modes. In general, AutoGraph will +replace these calls. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import data_structures + + +def stack(list_or_tensor, element_dtype=None, strict=True): + """Stacks the input, if it admits the notion of stacking. + + For example, a list of tensors can be stacked into a larger tensor. This + function is similar to tf.stack, but it accepts non-lists and lists of + non-tensors as arguments. In the latter case, the function does nothing. 
+ + Args: + list_or_tensor: Any + element_dtype: tf.DType, optional dtype for the elements in the list. + Required if the input is stackable, and the list is untyped. + strict: bool, if True an error is raised if the input is not stackable. + Otherwise the function is a no-op. + + Returns: + Any, if the input is stackable, the result will be a tf.Tensor. Otherwise, + if strict=False, the result will be list_or_tensor. + + Raises: + ValueError: if strict=True and the input is not stackable. + """ + if strict: + def raise_error(x): + raise ValueError('%s must be stackable when strict=True' % x) + original_call = raise_error + else: + original_call = lambda x: x + return data_structures.list_stack( + list_or_tensor, + data_structures.ListStackOpts( + element_dtype=element_dtype, original_call=original_call)) diff --git a/tensorflow/contrib/autograph/lang/special_functions_test.py b/tensorflow/contrib/autograph/lang/special_functions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a49cb6407517b634e0f1259fccda03d4ed18e83f --- /dev/null +++ b/tensorflow/contrib/autograph/lang/special_functions_test.py @@ -0,0 +1,54 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for special_functions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.lang import special_functions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SpecialFunctionsTest(test.TestCase): + + def test_basic(self): + self.assertEqual(special_functions.stack(1, strict=False), 1) + self.assertListEqual( + special_functions.stack([1, 2, 3], strict=False), [1, 2, 3]) + # TODO(mdan): This should probably forward to tf.stack. + self.assertTrue( + isinstance( + special_functions.stack( + [constant_op.constant(1), + constant_op.constant(2)], strict=False), list)) + + with self.assertRaises(ValueError): + special_functions.stack([1, 2, 3]) + + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor( + t, element_shape=constant_op.constant([], dtype=dtypes.int32)) + self.assertTrue( + tensor_util.is_tensor( + special_functions.stack(l, element_dtype=dtypes.float32))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 671c9ccc13eaa887522cfc248a6d56d7ab9719ca..988df70157170ed0a9ece33976e871e6f7693bbc 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -51,7 +51,7 @@ def for_stmt(iter_, extra_test, body, init_state): Args: iter_: The entity being iterated over. extra_test: Callable with the state as arguments, and boolean return type. - An additionnal loop condition. + An additional loop condition. 
body: Callable with the iterate and the state as arguments, and state as return type. The actual loop body. init_state: Tuple containing the initial state. diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py index ad97fdfa8e78d1fd4c38724612d83519c6609cce..ce746feeacf373874f9852d430eb37fadaf1e89e 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -286,7 +286,7 @@ class Forward(object): # TODO(alexbw): see if we can simplify by visiting breadth-first def visit(self, node): - """Depth-first walking the CFG, applying dataflow information propagtion.""" + """Depth-first walking the CFG, applying dataflow information propagation.""" # node.value is None only for the exit CfgNode. if not node.value: return diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index d6555dc7e0b3d49b3befa7326b28387509c83006..7d1e65c958d7787ef5ed707d4822d14a83092975 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -17,8 +17,8 @@ This analyzer uses known live values to further infer object types. This may include for instance constructed objects and object member functions. -In addition, the analyzer will also process annotations for TF (staged) type -annotations. +In addition, the analyzer also handles user annotations made in the code (for +example, the autograph.set_element_type function). Requires annotations generated by LiveValuesResolver. 
""" @@ -44,6 +44,7 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -159,12 +160,10 @@ class TypeInfoResolver(transformer.Base): # a = b # then for future references to `a` we should have definition = `b` definition = self.scope.getval(qn) - if anno.hasanno(definition, 'type'): - anno.setanno(node, 'type', anno.getanno(definition, 'type')) - anno.setanno(node, 'type_fqn', anno.getanno(definition, 'type_fqn')) - if anno.hasanno(definition, 'element_type'): - anno.setanno(node, 'element_type', - anno.getanno(definition, 'element_type')) + anno.copyanno(definition, node, 'type') + anno.copyanno(definition, node, 'type_fqn') + anno.copyanno(definition, node, 'element_type') + anno.copyanno(definition, node, 'element_shape') return node def _process_variable_assignment(self, target, value): @@ -211,23 +210,20 @@ class TypeInfoResolver(transformer.Base): if (anno.getanno(node.func, 'live_val') is self.context.type_annotation_func): - if len(node.args) != 2: - raise ValueError('"%s" must have exactly two parameters' + if len(node.args) < 2 or len(node.args) > 3: + raise ValueError('"%s" must have either two or three parameters' % self.context.type_annotation_func) - target_arg, type_arg = node.args + if len(node.args) == 2: + target_arg, type_arg = node.args + shape_arg = parser.parse_expression('None') + else: + target_arg, type_arg, shape_arg = node.args if not anno.hasanno(target_arg, anno.Basic.QN): raise ValueError('the first argument of "%s" must by a symbol' % self.context.type_annotation_func) - if isinstance(type_arg, gast.Str): - element_type = type_arg.s - elif isinstance(type_arg, gast.Num): - element_type = type_arg.n - else: - if not anno.hasanno(type_arg, 'live_val'): - raise ValueError( - 'the second argument of "%s" must be statically 
resolvable' % - self.context.type_annotation_func) - element_type = anno.getanno(type_arg, 'live_val') + # TODO(mdan): This is vulnerable to symbol renaming. + element_type = type_arg + element_shape = shape_arg target_symbol = anno.getanno(target_arg, anno.Basic.QN) # Find the definition of this symbol and annotate it with the given @@ -235,7 +231,9 @@ class TypeInfoResolver(transformer.Base): # to receive the same type annotation. definition = self.scope.getval(target_symbol) anno.setanno(node, 'element_type', element_type) + anno.setanno(node, 'element_shape', element_shape) anno.setanno(definition, 'element_type', element_type) + anno.setanno(definition, 'element_shape', element_shape) # TODO(mdan): Should we update references between definition and here? return self.generic_visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 95cbf5ca79a5045f5e050b735390dcfb668b5bb2..484562f294bb53a63feeca965b8f94c58aa2a685 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -187,14 +187,14 @@ class TypeInfoResolverTest(test.TestCase): def test_fn(): f = [] - f = utils.set_element_type(f, Foo) + f = utils.set_element_type(f, Foo, (1, 2, 3)) return f node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) f_def = node.body[0].body[0].value - self.assertEqual(anno.getanno(f_def, 'element_type'), Foo) + self.assertEqual(anno.getanno(f_def, 'element_type').id, 'Foo') f_ref = node.body[0].body[1].value - self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') def test_type_annotation_args(self): @@ -207,7 +207,7 @@ class TypeInfoResolverTest(test.TestCase): node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) f_ref = node.body[0].body[1].value - 
self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') def test_nested_unpacking(self): @@ -223,9 +223,9 @@ class TypeInfoResolverTest(test.TestCase): node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar}) a, b, c = node.body[0].body[1].value.elts - self.assertEquals(Foo, anno.getanno(a, 'type')) - self.assertEquals(Bar, anno.getanno(b, 'type')) - self.assertEquals(Foo, anno.getanno(c, 'type')) + self.assertEquals(anno.getanno(a, 'type'), Foo) + self.assertEquals(anno.getanno(b, 'type'), Bar) + self.assertEquals(anno.getanno(c, 'type'), Foo) self.assertFalse(anno.hasanno(a, 'live_val')) self.assertFalse(anno.hasanno(b, 'live_val')) self.assertFalse(anno.hasanno(c, 'live_val')) @@ -242,8 +242,8 @@ class TypeInfoResolverTest(test.TestCase): node = self._parse_and_analyze(test_fn, {'utils': utils}) a, b = node.body[0].body[2].body[2].value.elts - self.assertEquals(1, anno.getanno(a, 'element_type')) - self.assertEquals(2, anno.getanno(b, 'element_type')) + self.assertEquals(anno.getanno(a, 'element_type').n, 1) + self.assertEquals(anno.getanno(b, 'element_type').n, 2) self.assertFalse(anno.hasanno(a, 'type')) self.assertFalse(anno.hasanno(b, 'type')) self.assertFalse(anno.hasanno(a, 'live_val')) diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index 60bca8b38dcf62b4e997379d075cfc45511a894f..a656e99d21c6d3a1af831d3b34cf135b03c7ba29 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -191,7 +191,7 @@ class Base(gast.NodeTransformer): # TODO(mdan): Once we have error tracing, we may be able to just go to SSA. def apply_to_single_assignments(self, targets, values, apply_fn): - """Applies a fuction to each individual assignment. + """Applies a function to each individual assignment. This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. 
It tries to break down the unpacking if possible. In effect, it has the same @@ -219,7 +219,7 @@ class Base(gast.NodeTransformer): targets field of an ast.Assign node. values: an AST node. apply_fn: a function of a single argument, which will be called with the - respective nodes of each single assignment. The signaure is + respective nodes of each single assignment. The signature is apply_fn(target, value), no return value. """ if not isinstance(targets, (list, tuple)): diff --git a/tensorflow/contrib/batching/__init__.py b/tensorflow/contrib/batching/__init__.py index 44fa5f42a73bfb1bf008f6f4eafd14913c88dcfa..1e503a097a7b72d9244b0a1cf57747c4b4122c81 100644 --- a/tensorflow/contrib/batching/__init__.py +++ b/tensorflow/contrib/batching/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Ops and modules related to batch. +@@batch_function_v1 @@batch_function """ from __future__ import absolute_import diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 921d6917a4e478c3e60771fdc3ae99febc33d2e3..012a51f71101471850d312033c41dcbc4805d44c 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import @@ -83,6 +84,74 @@ def batch_function(num_batch_threads, SparseTensor is not supported. The return value of the decorated function must be a Tensor or a list/tuple of Tensors. + Args: + num_batch_threads: Number of scheduling threads for processing batches + of work. Determines the number of batches processed in parallel. + max_batch_size: Batch sizes will never be bigger than this. 
+ batch_timeout_micros: Maximum number of microseconds to wait before + outputting an incomplete batch. + allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, + does nothing. Otherwise, supplies a list of batch sizes, causing the op + to pad batches up to one of those sizes. The entries must increase + monotonically, and the final entry must equal max_batch_size. + grad_timeout_micros: The timeout to use for the gradient. See the + documentation of the unbatch op for more details. Defaults to 60s. + unbatch_timeout_micros: The timeout to use for unbatching. See the + documentation of the unbatch op for more details. Defaults to 60s. + max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. + + Returns: + The decorated function will return the unbatched computation output Tensors. + """ + + def decorator(fn): # pylint: disable=missing-docstring + + def decorated(*args): # pylint: disable=missing-docstring + types = [arg.dtype for arg in args] + + @function.Defun(*types) + def computation(*computation_args): + return fn(*computation_args) + + with ops.name_scope("batch") as name: + for a in args: + if not isinstance(a, ops.Tensor): + raise ValueError("All arguments to functions decorated with " + "`batch_function` are supposed to be Tensors; " + "found %s" % repr(a)) + for inp in computation.captured_inputs: + print("inp: %s" % inp) + for op in inp.consumers(): + print("op: %s" % op) + return gen_batch_ops.batch_function( + num_batch_threads=num_batch_threads, + max_batch_size=max_batch_size, + batch_timeout_micros=batch_timeout_micros, + allowed_batch_sizes=allowed_batch_sizes, + max_enqueued_batches=max_enqueued_batches, + shared_name=name, + f=computation, + in_tensors=list(args), + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + return decorated + + return decorator + + +def batch_function_v1(num_batch_threads, + max_batch_size, + batch_timeout_micros, 
+ allowed_batch_sizes=None, + grad_timeout_micros=60 * 1000 * 1000, + unbatch_timeout_micros=60 * 1000 * 1000, + max_enqueued_batches=10): + """Batches the computation done by the decorated function. + + This is the older version of batch_function(). Please use batch_function() + instead of this one. + Args: num_batch_threads: Number of scheduling threads for processing batches of work. Determines the number of batches processed in parallel. diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index ea8339334f9b5e58a35dc9edf314a220e4c9868c..78468145469df216344bc00f116add250dc51dd3 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -188,12 +188,62 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBasicUnbatchV1Decorated(self): + """Tests that the batch_function_v1 decorator works.""" + with self.test_session() as sess: + @batch_ops.batch_function_v1(1, 10, 100000) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testBasicUnbatchDecorated(self): """Tests that the batch_function decorator works.""" with self.test_session() as sess: + # TODO(apassos): Removing this line causes test flakiness! Ideally should + # be investigated. 
+ default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable + @batch_ops.batch_function(1, 10, 100000) def computation(in_t): return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchDecoratedWithCapturedInput(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return in_t + captured_inp0 - captured_inp1 + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) result = computation(inp) thread_results = [] diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py index 5770bcdd706723394bb06196d24aeb32b8b8491a..68fa415eeaf1d1ae7c2ecf1be1c300eddbfa4e69 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monte Carlo integration and helpers. - -See the @{$python/contrib.bayesflow.monte_carlo} guide. 
-""" +"""Monte Carlo integration and helpers.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 758754feac31f1d2cf10e69d7a9a6d288931c900..911d87fa10570382ee5f03edfc1bfd1d116c8360 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -232,7 +232,13 @@ def _dnn_tree_combined_model_fn(features, return update_op if predict_with_tree_only: - tree_train_logits = tree_logits + if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.PREDICT: + tree_train_logits = tree_logits + else: + tree_train_logits = control_flow_ops.cond( + global_step > dnn_steps_to_train, + lambda: tree_logits, + lambda: dnn_logits) else: tree_train_logits = dnn_logits + tree_logits diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 8ae493ba998bd882b5ef946f927ec1882d91f61d..9aa4614967958247dde5d81b862baaafd8d4144a 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -16,9 +16,11 @@ Visualization and inspection: @@dot_graph_from_checkpoint +@@list_objects @@object_metadata Managing dependencies: +@@capture_dependencies @@Checkpointable @@CheckpointableObjectGraph @@NoDependency @@ -42,6 +44,8 @@ from tensorflow.python.training.checkpointable.base import Checkpointable from tensorflow.python.training.checkpointable.base import NoDependency from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.data_structures import Mapping +from tensorflow.python.training.checkpointable.util import capture_dependencies +from tensorflow.python.training.checkpointable.util import list_objects from 
tensorflow.python.training.checkpointable.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index 42ba368531468b789a87429f88ca84937f9b909d..1a7a3759baa4a5559b4b70ff4f7467c41da9111f 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -74,3 +74,14 @@ tf_py_test( ], tags = ["manual"], ) + +tf_py_test( + name = "gcs_config_ops_test", + size = "small", + srcs = ["python/ops/gcs_config_ops_test.py"], + additional_deps = [ + ":cloud_py", + "//tensorflow/python:client_testlib", + ], + tags = ["manual"], +) diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index a6e13ea3ae938444b9ead0772e52fb8797a847da..ef7aa7624ce7b9b6480c4d088a2fb7678a7acc76 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -27,8 +27,9 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'BigQueryReader', - 'ConfigureColabSession', - 'ConfigureGcs', + 'BlockCacheParams', + 'configure_colab_session', + 'configure_gcs', 'ConfigureGcsHook', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 40160706f70e8fa8323005dd183770ed51c8c415..1311063ec023bdaa2588d6f1c826bf900f7dea09 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -79,6 +79,7 @@ tf_kernel_library( srcs = ["gcs_config_ops.cc"], visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/cloud:gcs_config_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/platform/cloud:curl_http_request", diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py new file mode 100644 index 
0000000000000000000000000000000000000000..fc0c9948129116ac371c64fc01a96ecc6194e244 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py @@ -0,0 +1,34 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the gcs_config_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cloud.python.ops import gcs_config_ops +from tensorflow.python.platform import test + + +class GcsConfigOpsTest(test.TestCase): + + def testSetBlockCache(self): + cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024) + with self.test_session() as sess: + gcs_config_ops.configure_gcs(sess, block_cache=cfg) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index a5a9630a4aa382fb13d8fc88e575e094e575cc87..8f521ffee4d31e090c13bac98290656d6e1d330e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -36,6 +36,7 @@ except ImportError: _GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' +_ENDPOINTS_SEPARATOR = ',' _DEFAULT_ENV_VARIABLE = 
'TPU_NAME' _DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' @@ -69,8 +70,8 @@ class TPUClusterResolver(ClusterResolver): return _GKE_ENV_VARIABLE in os.environ @staticmethod - def _gkeMaster(): - return os.environ[_GKE_ENV_VARIABLE].split(',')[0] + def _gkeEndpoints(): + return os.environ[_GKE_ENV_VARIABLE] @staticmethod def _envVarFallback(): @@ -143,7 +144,7 @@ class TPUClusterResolver(ClusterResolver): # When using GKE with Cloud TPUs, the env variable will be set. if tpu is None: if in_gke: - tpu = self._gkeMaster() + tpu = self._gkeEndpoints() else: tpu = self._envVarFallback() @@ -214,7 +215,7 @@ class TPUClusterResolver(ClusterResolver): ValueError: If none of the TPUs specified exists. """ if not self._shouldResolve(): - return self._tpu + return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] job_tasks = self.cluster_spec().job_tasks(self._job_name) if not job_tasks: @@ -256,6 +257,10 @@ class TPUClusterResolver(ClusterResolver): request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() + if 'state' in response and response['state'] != 'READY': + raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % + (self._tpu, response['state'])) + if 'health' in response and response['health'] != 'HEALTHY': raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, response['health'])) @@ -276,8 +281,12 @@ class TPUClusterResolver(ClusterResolver): # Case 3. return None # Case 2. 
- cluster_spec = {self._job_name: [self._tpu[len( - compat.as_bytes('grpc://')):]]} + cluster_spec = { + self._job_name: [ + x[len(compat.as_bytes('grpc://')):] + for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR)) + ] + } if self._coordinator_address: # {1, 2}.a diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5fac55fd027fa2d100621e08a09e05cdb3a1b941..ad4f6432630be44a7de6e778f55f1fb7fd66f307 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -158,6 +158,50 @@ class TPUClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testUnhealthyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'health': 'UNHEALTHY' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testNotReadyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'state': 'CREATING' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + 
tpu_cluster_resolver.cluster_spec() + def testSimpleSuccessfulRetrieval(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { @@ -358,13 +402,61 @@ class TPUClusterResolverTest(test.TestCase): compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master()) self.assertEqual(None, tpu_cluster_resolver.cluster_spec()) - def testGkeEnvironment(self): + def testGkeEnvironmentForDonut(self): os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' - self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ) + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) + self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(tpu_cluster_resolver.master())) + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + + def testGkeEnvironmentForPod(self): + os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470') + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() self.assertEqual( compat.as_bytes('grpc://10.120.27.5:8470'), - compat.as_bytes(TPUClusterResolver._gkeMaster())) + compat.as_bytes(tpu_cluster_resolver.master())) + 
actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + tasks { key: 1 value: '10.120.27.6:8470' } + tasks { key: 2 value: '10.120.27.7:8470' } + tasks { key: 3 value: '10.120.27.8:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] def testDiscoveryUrl(self): diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 0708d6b7b9f0ba549aea091a265f42890e50d223..e524e9e7437b19e0d117fe7b85042e8154773a02 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -18,7 +18,16 @@ cmake_policy(SET CMP0022 NEW) # Options option(tensorflow_VERBOSE "Enable for verbose output" OFF) + +if(WIN32) +# BoringSSL is disabled for windows as it currently doesn't build with +# MSBuild. (Ninja is required.) option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF) +else() +# BoringSSL is enabled for gRPC. 
+option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" ON) +endif() + option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF) diff --git a/tensorflow/contrib/cmake/external/double_conversion.cmake b/tensorflow/contrib/cmake/external/double_conversion.cmake index 527ccdc8d887cb4c2e7d2412c99a8bc682568472..5c5adaf5798289fba1c5d0b3f9e0489dc242043e 100644 --- a/tensorflow/contrib/cmake/external/double_conversion.cmake +++ b/tensorflow/contrib/cmake/external/double_conversion.cmake @@ -16,15 +16,15 @@ include (ExternalProject) set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion) set(double_conversion_URL https://github.com/google/double-conversion.git) -set(double_conversion_TAG 5664746) +set(double_conversion_TAG 3992066a95b823efc8ccc1baf82a1cfc73f6e9b8) set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR}) set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so) set(double_conversion_INCLUDES ${double_conversion_BUILD}) if(WIN32) - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/$(Configuration)/double-conversion.lib) + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/$(Configuration)/double-conversion.lib) else() - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.a) + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/libdouble-conversion.a) endif() set(double_conversion_HEADERS diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 693dc7cd673233b889b35a3f3170b57581da9a9f..b1e64aa55c80ad59cfdc0f4767c0282b4f73367f 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -20,6 +20,10 
@@ set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) if(WIN32) + # We use unsecure gRPC because boringssl does not build on windows + set(grpc_TARGET grpc++_unsecure) + set(grpc_DEPENDS protobuf zlib) + set(grpc_SSL_PROVIDER NONE) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib @@ -32,9 +36,12 @@ if(WIN32) ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib) endif() else() + set(grpc_TARGET grpc++) + set(grpc_DEPENDS boringssl protobuf zlib) + set(grpc_SSL_PROVIDER module) set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) @@ -44,13 +51,13 @@ add_definitions(-DGRPC_ARES=0) ExternalProject_Add(grpc PREFIX grpc - DEPENDS protobuf zlib + DEPENDS ${grpc_DEPENDS} GIT_REPOSITORY ${GRPC_URL} GIT_TAG ${GRPC_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target ${grpc_TARGET} COMMAND ${CMAKE_COMMAND} --build . 
--config Release --target grpc_cpp_plugin INSTALL_COMMAND "" CMAKE_CACHE_ARGS @@ -59,7 +66,7 @@ ExternalProject_Add(grpc -DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS} -DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} - -DgRPC_SSL_PROVIDER:STRING=NONE + -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} ) # grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h. diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 1959ad028a06f3c1ff6a658d656155541891fd13..92446044892127284ecb8753a250b77cb2a5743a 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -756,6 +756,8 @@ add_custom_command( "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" "${api_init_list_file}" COMMENT "Generating __init__.py files for Python API." @@ -765,7 +767,49 @@ add_custom_command( add_custom_target(tf_python_api SOURCES ${api_init_files}) add_dependencies(tf_python_api tf_python_ops) +# TODO(mikecase): This can be removed once tf.estimator is moved +# out of TensorFlow. +######################################################## +# Generate API __init__.py files for tf.estimator. +######################################################## + +# Parse tensorflow/tools/api/generator/BUILD to get list of generated files. 
+FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text}) +string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "," ";" api_init_files_list ${api_init_files_text}) + +set(api_init_files "") +foreach(api_init_file ${api_init_files_list}) + string(STRIP "${api_init_file}" api_init_file) + if(api_init_file) + string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api/${api_init_file}") + endif() +endforeach(api_init_file) +set(estimator_api_init_list_file "${tensorflow_source_dir}/estimator_api_init_files_list.txt") +file(WRITE "${estimator_api_init_list_file}" "${api_init_files}") + +# Run create_python_api.py to generate __init__.py files. +add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api" + "--package=tensorflow.python.estimator" + "--apiname=estimator" + "${estimator_api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." 
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" +) +add_custom_target(estimator_python_api SOURCES ${api_init_files}) +add_dependencies(estimator_python_api tf_python_ops) ############################################################ # Build a PIP package containing the TensorFlow runtime. ############################################################ @@ -776,6 +820,7 @@ add_dependencies(tf_python_build_pip_package tf_python_touchup_modules tf_python_ops tf_python_api + estimator_python_api tf_extension_ops) # Fix-up Python files that were not included by the add_python_module() macros. diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index eb9482dc25f2be8ce46cc38bf3dd28889b09a9d4..c8de8db126f7724386be565aa524b4b527976730 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -325,6 +325,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py" # b/71901810 # Broken io_utils_test "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py" # b/72894325 + # OOM + "${tensorflow_source_dir}/tensorflow/python/training/saver_large_variable_test.py" # b/110210559 ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md index c65a150464efc1e77419040f66f36fc6756325aa..cb1dd7d836ae11700b2ffaaff4fda5b7f943f87d 100644 --- a/tensorflow/contrib/constrained_optimization/README.md +++ b/tensorflow/contrib/constrained_optimization/README.md @@ -46,7 +46,7 @@ document. Imagine that we want to constrain the recall of a binary classifier to be at least 90%. 
Since the recall is proportional to the number of true positive classifications, which itself is a sum of indicator functions, this constraint -is non-differentible, and therefore cannot be used in a problem that will be +is non-differentiable, and therefore cannot be used in a problem that will be optimized using a (stochastic) gradient-based algorithm. For this and similar problems, TFCO supports so-called *proxy constraints*, diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py index 9ffad9caa92d2d3be8f598758a443b0eceb8d4d8..90371cd8d70db11dc77af02a2b1fd2a90f3dcf44 100644 --- a/tensorflow/contrib/control_flow/python/cond_v2.py +++ b/tensorflow/contrib/control_flow/python/cond_v2.py @@ -44,11 +44,34 @@ from tensorflow.python.util import compat def cond_v2(pred, true_fn, false_fn, name="cond"): """Like tf.cond, except emits a single If op.""" + if not name: + name = "cond" + with ops.name_scope(name) as scope: - true_graph = function.func_graph_from_py_func(true_fn, [], [], - name="%s_true" % scope) - false_graph = function.func_graph_from_py_func(false_fn, [], [], - name="%s_false" % scope) + # Identify if there is a caller device, & get the innermost if possible. 
+ device_stack = ops.get_default_graph()._device_function_stack + caller_device = device_stack[-1] if device_stack else None + + caller_colocation_stack = ops.get_default_graph()._colocation_stack + caller_container = ops.get_default_graph()._container + caller_collection_ref = ops.get_default_graph()._collections + + func_name_prefix = scope.replace("/", "_") + + true_graph = function.func_graph_from_py_func( + true_fn, [], [], + name="%strue" % func_name_prefix, + device=caller_device, + colocation_stack=caller_colocation_stack, + collections_ref=caller_collection_ref, + container=caller_container) + false_graph = function.func_graph_from_py_func( + false_fn, [], [], + name="%sfalse" % func_name_prefix, + device=caller_device, + colocation_stack=caller_colocation_stack, + collections_ref=caller_collection_ref, + container=caller_container) _check_same_outputs(true_graph, false_graph) # Add inputs to true_graph and false_graph to make them match. Note that diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/contrib/control_flow/python/cond_v2_test.py index dcecefb520ee4bee276f1682f6a90550ffa7e547..94ed3e130ba06c129d96c4ea775a043b5bc9b3ea 100644 --- a/tensorflow/contrib/control_flow/python/cond_v2_test.py +++ b/tensorflow/contrib/control_flow/python/cond_v2_test.py @@ -25,10 +25,13 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import saver +from tensorflow.python.util import compat class NewCondTest(test.TestCase): @@ -81,6 +84,52 @@ class NewCondTest(test.TestCase): self._testCond(true_fn, false_fn, [x, y]) self._testCond(true_fn, 
false_fn, [y]) + def testNoInputs(self): + pred = array_ops.placeholder(dtypes.bool, name="pred") + + def true_fn(): + return constant_op.constant(1.0) + + def false_fn(): + return constant_op.constant(2.0) + + out = cond_v2.cond_v2(pred, true_fn, false_fn) + + with self.test_session() as sess: + self.assertEqual(sess.run(out, {pred: True}), [1.0]) + self.assertEqual(sess.run(out, {pred: False}), [2.0]) + + def _createCond(self, name): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(1.0, name="x") + + def true_fn(): + return x + + def false_fn(): + return x + 1 + + return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)[0].op + + def testDefaultName(self): + with ops.Graph().as_default(): + cond = self._createCond(None) + self.assertEqual(cond.name, "cond") + self.assertIn("cond_true", ops.get_default_graph()._functions) + self.assertIn("cond_false", ops.get_default_graph()._functions) + + with ops.Graph().as_default(): + with ops.name_scope("foo"): + cond = self._createCond("") + self.assertEqual(cond.name, "foo/cond") + self.assertIn("foo_cond_true", ops.get_default_graph()._functions) + self.assertIn("foo_cond_false", ops.get_default_graph()._functions) + + cond2 = self._createCond(None) + self.assertEqual(cond2.name, "foo/cond_1") + self.assertIn("foo_cond_1_true", ops.get_default_graph()._functions) + self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions) + def testSecondDerivative(self): pred = array_ops.placeholder(dtypes.bool, name="pred") x = constant_op.constant(3.0, name="x") @@ -152,5 +201,225 @@ class NewCondTest(test.TestCase): self.assertEqual(false_val, [0.0]) +class CondV2CollectionTest(test.TestCase): + + def testCollectionIntValueAccessInCond(self): + """Read values from graph collections inside of cond_v2.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = 2 + y = 5 + ops.add_to_collection("x", x) + ops.add_to_collection("y", y) + def fn(): + x_const = 
constant_op.constant(ops.get_collection("x")[0]) + y_const = constant_op.constant(ops.get_collection("y")[0]) + return math_ops.add(x_const, y_const) + + cnd = cond_v2.cond_v2(True, fn, fn) + self.assertEquals(cnd[0].eval(), 7) + + def testCollectionTensorValueAccessInCond(self): + """Read tensors from collections inside of cond_v2 & use them.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = constant_op.constant(2) + y = constant_op.constant(5) + ops.add_to_collection("x", x) + ops.add_to_collection("y", y) + + def fn(): + x_read = ops.get_collection("x")[0] + y_read = ops.get_collection("y")[0] + return math_ops.add(x_read, y_read) + + cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn) + self.assertEquals(cnd[0].eval(), 7) + + def testCollectionIntValueWriteInCond(self): + """Make sure Int writes to collections work inside of cond_v2.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = constant_op.constant(2) + y = constant_op.constant(5) + def true_fn(): + z = math_ops.add(x, y) + ops.add_to_collection("z", 7) + return math_ops.mul(x, z) + + def false_fn(): + z = math_ops.add(x, y) + return math_ops.mul(x, z) + + cnd = cond_v2.cond_v2( + True, true_fn, + false_fn) + self.assertEquals(cnd[0].eval(), 14) + + read_z_collection = ops.get_collection("z") + self.assertEquals(read_z_collection, [7]) + + +class CondV2ContainerTest(test.TestCase): + + def testContainer(self): + """Set containers outside & inside of cond_v2. 
+ + Make sure the containers are set correctly for both variable creation + (tested by variables.Variable) and for stateful ops (tested by FIFOQueue) + """ + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + v0 = variables.Variable([0]) + q0 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + def container(node): + return node.op.get_attr("container") + + self.assertEqual(compat.as_bytes(""), container(v0)) + self.assertEqual(compat.as_bytes(""), container(q0.queue_ref)) + + def true_fn(): + # When this branch is created in cond below, + # the container should begin with 'l1' + v1 = variables.Variable([1]) + q1 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + with ops.container("l2t"): + v2 = variables.Variable([2]) + q2 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + v3 = variables.Variable([1]) + q3 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v1)) + self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref)) + self.assertEqual(compat.as_bytes("l2t"), container(v2)) + self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref)) + self.assertEqual(compat.as_bytes("l1"), container(v3)) + self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref)) + + return constant_op.constant(2.0) + + def false_fn(): + # When this branch is created in cond below, + # the container should begin with 'l1' + v1 = variables.Variable([1]) + q1 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + with ops.container("l2f"): + v2 = variables.Variable([2]) + q2 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + v3 = variables.Variable([1]) + q3 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v1)) + self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref)) + self.assertEqual(compat.as_bytes("l2f"), container(v2)) + self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref)) + self.assertEqual(compat.as_bytes("l1"), container(v3)) + 
self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref)) + + return constant_op.constant(6.0) + + with ops.container("l1"): + cnd_true = cond_v2.cond_v2(True, true_fn, false_fn) + self.assertEquals(cnd_true[0].eval(), 2) + + cnd_false = cond_v2.cond_v2(False, true_fn, false_fn) + self.assertEquals(cnd_false[0].eval(), 6) + + v4 = variables.Variable([3]) + q4 = data_flow_ops.FIFOQueue(1, dtypes.float32) + v5 = variables.Variable([4]) + q5 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v4)) + self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref)) + self.assertEqual(compat.as_bytes(""), container(v5)) + self.assertEqual(compat.as_bytes(""), container(q5.queue_ref)) + + +class CondV2ColocationGroupAndDeviceTest(test.TestCase): + + def testColocateWithBeforeCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + a = constant_op.constant([2.0], name="a") + b = constant_op.constant([2.0], name="b") + + def fn(): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3) + + def fn2(): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + with ops.colocate_with(b.op): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + def testColocateWithInAndOutOfCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + a = constant_op.constant([2.0], name="a") + b = constant_op.constant([2.0], name="b") + + def fn2(): + with ops.colocate_with(b.op): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + d = constant_op.constant([2.0], name="d") + 
self.assertEqual([b"loc:@a"], d.op.colocation_groups()) + + def testDeviceBeforeCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + def fn(): + c = constant_op.constant(3.0) + self.assertEqual("/device:CPU:0", c.op.device) + return c + + with ops.device("/device:CPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3) + + def fn2(): + c = constant_op.constant(3.0) + self.assertEqual("/device:GPU:0", c.op.device) + return c + + with ops.device("/device:GPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + def testDeviceInAndOutOfCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + def fn2(): + with ops.device("/device:GPU:0"): + c = constant_op.constant(3.0) + self.assertEqual("/device:GPU:0", c.op.device) + return c + + with ops.device("/device:CPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + d = constant_op.constant(4.0) + self.assertEqual("/device:CPU:0", d.op.device) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index e88ad3dc32003ece2b8810661cd4db374196561c..4657807785d58727d34f37172bd30c56a5b7cde6 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -236,7 +236,7 @@ class CSVDatasetOp : public DatasetOpKernel { size_t num_parsed = 0; size_t num_selected_parsed = 0; - Status result = Status::OK(); + Status result; while (!end_of_record) { // Read till we reach \n, \r or EOF bool include = @@ -329,6 +329,7 @@ class CSVDatasetOp : public DatasetOpKernel { size_t start = pos_; pos_++; // Starting quotation mark + Status parse_result; while (true) { // Each iter reads 1 char, filling buffer if necessary if (pos_ >= buffer_.size()) { Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); @@ -351,8 +352,9 @@ class CSVDatasetOp : public 
DatasetOpKernel { if (errors::IsOutOfRange(s)) { // This was the last field. We are done *end_of_record = true; - return QuotedFieldToOutput(ctx, StringPiece(), out_tensors, - earlier_pieces, include); + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(), out_tensors, earlier_pieces, include)); + return parse_result; } else if (!s.ok()) { return s; } @@ -361,20 +363,24 @@ class CSVDatasetOp : public DatasetOpKernel { char next = buffer_[pos_]; pos_++; if (next == dataset()->delim_) { - return QuotedFieldToOutput( + parse_result.Update(QuotedFieldToOutput( ctx, StringPiece(&buffer_[start], pos_ - 1 - start), - out_tensors, earlier_pieces, include); + out_tensors, earlier_pieces, include)); + return parse_result; } else if (next == '\n' || next == '\r') { *end_of_record = true; - Status s = QuotedFieldToOutput( + parse_result.Update(QuotedFieldToOutput( ctx, StringPiece(&buffer_[start], pos_ - 1 - start), - out_tensors, earlier_pieces, include); + out_tensors, earlier_pieces, include)); if (next == '\r') SkipNewLineIfNecessary(); - return s; + return parse_result; } else if (next != '"') { - return errors::InvalidArgument( - "Quote inside a string has to be escaped by another quote"); + // Take note of the error, but keep going to end of field. + include = false; // So we don't get funky errors when trying to + // unescape the quotes. 
+ parse_result.Update(errors::InvalidArgument( + "Quote inside a string has to be escaped by another quote")); } } else { @@ -454,6 +460,8 @@ class CSVDatasetOp : public DatasetOpKernel { EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::vector earlier_pieces; size_t start = pos_; + Status parse_result; + while (true) { // Each iter reads 1 char, filling buffer if necessary if (pos_ >= buffer_.size()) { Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); @@ -461,9 +469,10 @@ class CSVDatasetOp : public DatasetOpKernel { if (errors::IsOutOfRange(s)) { // Whatever we have is the last field of the last record *end_of_record = true; - return UnquotedFieldToOutput( + parse_result.Update(UnquotedFieldToOutput( ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); + earlier_pieces, include)); + return parse_result; } else if (!s.ok()) { return s; // Surface all other errors to caller } @@ -472,66 +481,33 @@ class CSVDatasetOp : public DatasetOpKernel { char ch = buffer_[pos_]; if (ch == dataset()->delim_) { - Status s = UnquotedFieldToOutput( + parse_result.Update(UnquotedFieldToOutput( ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); + earlier_pieces, include)); pos_++; - return s; + return parse_result; } if (ch == '\n' || ch == '\r') { // need special case to skip over first \n of record if the line // breaks are \r\n - Status s = UnquotedFieldToOutput( + parse_result.Update(UnquotedFieldToOutput( ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); + earlier_pieces, include)); *end_of_record = true; pos_++; if (ch == '\r') SkipNewLineIfNecessary(); - return s; + return parse_result; } if (dataset()->use_quote_delim_ && ch == '"') { - // Advance pos_ to the next field anyway so that we can ignore - // errors gracefully if required. The caller of this will be able to - // call ParseOneField and continue with the rest of the record. 
- AdvanceToNextField(end_of_record); - return errors::InvalidArgument( - "Unquoted fields cannot have quotes inside"); + // Take note of the error, but keep going to end of field. + parse_result.Update(errors::InvalidArgument( + "Unquoted fields cannot have quotes inside")); } // Otherwise, go to next character pos_++; } } - // Advances pos_ to the start of the next field, as delimited by delim, - // CRLF, or EOF, ignoring errors, and not keeping track of characters in - // the current field. - void AdvanceToNextField(bool* end_of_record) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - while (true) { - if (pos_ >= buffer_.size()) { - Status s = FillBuffer(&buffer_); - pos_ = 0; - if (!s.ok()) { - *end_of_record = true; - return; - } - } - - char ch = buffer_[pos_]; - pos_++; - - if (ch == dataset()->delim_) { - return; - } - - if (ch == '\n' || ch == '\r') { - *end_of_record = true; - if (ch == '\r') SkipNewLineIfNecessary(); - return; - } - } - } - Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { result->clear(); Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index ba707d8d6e466561442f48e5dd7e8bdee20fb0f7..4e3f9801d7144695478d7fcf2fbc9ecf6e57117a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -30,6 +30,7 @@ py_test( "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -54,6 +55,19 @@ py_test( ], ) +py_test( + name = "cache_dataset_op_test", + size = "small", + srcs = ["cache_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "concatenate_dataset_op_test", size = "small", @@ -330,6 
+344,26 @@ py_test( ], ) +py_library( + name = "reader_dataset_ops_test_base", + testonly = 1, + srcs = [ + "reader_dataset_ops_test_base.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:lib", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:readers", + ], +) + py_test( name = "reader_dataset_ops_test", size = "medium", @@ -339,8 +373,8 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", + ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -352,6 +386,7 @@ py_test( "//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", "//third_party/py/numpy", ], ) @@ -441,6 +476,7 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -448,6 +484,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", @@ -478,10 +515,15 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", + ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + 
"//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index b5fbc45ad3d8d262c1c79b5723ffeb38ff6a34c2..1435503beb96104c0a845bb064165099c680613a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import math import time +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base @@ -40,7 +41,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class BatchDatasetTest(test.TestCase): +class BatchDatasetTest(test.TestCase, parameterized.TestCase): def assertSparseValuesEqual(self, a, b): self.assertAllEqual(a.indices, b.indices) @@ -427,9 +428,13 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testMapAndBatchDatasetHelper(self, - num_parallel_calls=None, - num_parallel_batches=None): + @parameterized.named_parameters( + ("default", None, None), + ("sequential_calls", 1, None), + ("parallel_calls", 2, None), + ("parallel_batches", None, 10), + ) + def testMapAndBatch(self, num_parallel_calls, num_parallel_batches): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). 
@@ -500,19 +505,11 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testMapAndBatch(self): - return self._testMapAndBatchDatasetHelper() - - def testMapAndBatchWithParallelBatches(self): - return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) - - def testMapAndBatchWithSequentialCalls(self): - return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) - - def testMapAndBatchWithParallelCalls(self): - return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) - - def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): + @parameterized.named_parameters( + ("even", False), + ("uneven", True), + ) + def testMapAndBatchPartialBatch(self, drop_remainder): iterator = ( dataset_ops.Dataset.range(10).apply( batching.map_and_batch( @@ -532,12 +529,6 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) - def testMapAndBatchPartialBatch(self): - return self._testMapAndBatchPartialBatchHelper() - - def testMapAndBatchPartialBatchDropRemainder(self): - return self._testMapAndBatchPartialBatchHelper(drop_remainder=True) - def testMapAndBatchYieldsPartialBatch(self): iterator = (dataset_ops.Dataset.range(10) .apply(batching.map_and_batch( @@ -614,7 +605,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testMapAndBatchDatasetFails(self): + def testMapAndBatchFails(self): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( @@ -628,7 +619,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) - def testMapAndBatchDatasetShapeMismatch(self): + def testMapAndBatchShapeMismatch(self): """Test a dataset that maps a TF function across 
its input elements.""" def generator(): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index bd3e034211c4aa454e4f8f6b09f14935d7a3b35c..4fbfbfdbdd7ffd1019cef5bab7ffd5c149c37fcc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -68,7 +68,7 @@ class GroupByReducerTest(test.TestCase): reducer = grouping.Reducer( init_func=lambda _: (0.0, 0.0), reduce_func=reduce_fn, - finalize_func=lambda x: x[0]) + finalize_func=lambda x, _: x) for i in range(1, 11): dataset = dataset_ops.Dataset.range(2 * i).apply( grouping.group_by_reducer( @@ -121,7 +121,7 @@ class GroupByReducerTest(test.TestCase): reducer = grouping.Reducer( init_func=lambda x: ([0], 1), reduce_func=reduce_fn, - finalize_func=lambda x: x) + finalize_func=lambda x, y: (x, y)) for i in range(1, 11): dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f08216a303e2d7dee155ccadcdb9f42f1b24ea0f --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental features of CacheDataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class CacheToFileDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.range_size = 10 + self.num_repeats = 3 + self.num_outputs = self.range_size * self.num_repeats + self.cache_file_prefix = 'test' + + def ds_fn(self): + return dataset_ops.Dataset.range(self.range_size).cache( + os.path.join(self.get_temp_dir(), + self.cache_file_prefix)).repeat(self.num_repeats) + + def expected_outputs(self): + return list(range(self.range_size)) * self.num_repeats + + def testCheckpointBeforeOneEpoch(self): + # Generate 5 entries from iterator and save checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce the rest of the elements from the + # iterator. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + + def testCheckpointBeforeOneEpochThenRunFewSteps(self): + # Generate 8 entries from iterator but save checkpoint after producing + # 5. 
+ outputs = self.gen_outputs( + self.ds_fn, [5], + 8, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, range(8)) + + # Restoring from checkpoint and running GetNext should return an + # `AlreadyExistsError` now because the lockfile already exists. + with self.assertRaises(errors.AlreadyExistsError): + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + + def testCheckpointAfterOneEpoch(self): + # Generate 15 entries from iterator and save checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) + + # Restore from checkpoint and produce the rest of the elements from the + # iterator. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 15, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + + def testCheckpointAfterOneEpochThenRunFewSteps(self): + # Generate 18 entries from iterator but save checkpoint after producing + # 15. + outputs = self.gen_outputs( + self.ds_fn, [15], + 18, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(8))) + + outputs = list(range(10)) + list(range(5)) + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 15, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointBeforeOneEpochButRunCompleteEpoch(self): + # Generate 13 entries from iterator but save checkpoint after producing + # 5. + outputs = self.gen_outputs( + self.ds_fn, [5], + 13, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(3))) + + # Since we ran for more than one epoch, the cache was completely written. + # The ckpt was saved when the iterator was in cache-write mode.
Test that + # the iterator falls back to read mode after restoring if the cache has + # been completely written. + + outputs = list(range(5)) + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointUnusedWriterIterator(self): + # Checkpoint before get_next is called even once. + outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False) + self.assertSequenceEqual(outputs, []) + + outputs = self.gen_outputs( + self.ds_fn, [], + self.num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointUnusedMidwayWriterIterator(self): + # Produce 5 elements and checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint, then produce no elements and checkpoint. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce rest of the elements. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testUnusedCheckpointError(self): + # Produce 5 elements and save ckpt. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Since the complete cache has not been written, a new iterator which does + # not restore the checkpoint will throw an error since there is a partial + # cache shard. + with self.assertRaises(errors.AlreadyExistsError): + outputs = self.gen_outputs( + self.ds_fn, [], self.num_outputs, verify_exhausted=False) + + def testIgnoreCheckpointIfCacheWritten(self): + # Produce 15 elements and save ckpt. 
This will write the complete cache. + outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) + + # Build the iterator again but do not restore from ckpt. Since the cache + # has already been written we should be able to use it. + outputs = self.gen_outputs( + self.ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 74b90ec7d1617d221888d1e1c56cf594c367ddf9..97b5e9416521dcad9ee5047a8275f8fd0142e338 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -162,9 +162,28 @@ class CsvDatasetOpTest(test.TestCase): expected_err_re='Unquoted fields cannot have quotes inside', record_defaults=record_defaults) + def testCsvDataset_errWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['"a"b","c","d"']] + self._test_dataset( + inputs, + expected_err_re= + 'Quote inside a string has to be escaped by another quote', + record_defaults=record_defaults) + + def testCsvDataset_ignoreErrWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + def testCsvDataset_ignoreErrWithUnquotedQuotes(self): record_defaults = [['']] * 3 - inputs = [['1,2"3,4', 'a,b,c"d', 'e,f,g']] + inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] filenames = 
self.setup_files(inputs) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index 78ecce8f7daaf84002ae78d8d77820755b967d89..393f08850b1865180a8b94e9209b2445b54c8b69 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -467,7 +467,8 @@ class DatasetSerializationTestBase(test.TestCase): ckpt_saved=False, init_before_restore=False, sparse_tensors=False, - verify_exhausted=True): + verify_exhausted=True, + save_checkpoint_at_end=True): """Generates elements from input dataset while stopping at break points. Produces `num_outputs` outputs and saves the state of the iterator in the @@ -490,6 +491,10 @@ class DatasetSerializationTestBase(test.TestCase): sparse_tensors: Whether dataset is built from SparseTensor(s). verify_exhausted: Whether to verify that the iterator has been exhausted after producing `num_outputs` elements. + save_checkpoint_at_end: Whether to save a checkpoint after producing all + outputs. If False, checkpoints are saved each break point but not at the + end. Note that checkpoints overwrite each other so there is always only + a single checkpoint available. Defaults to True. Returns: A list of `num_outputs` items. 
@@ -526,8 +531,9 @@ class DatasetSerializationTestBase(test.TestCase): if i == len(break_points) and verify_exhausted: with self.assertRaises(errors.OutOfRangeError): sess.run(get_next_op) - self._save(sess, saver) - ckpt_saved = True + if save_checkpoint_at_end or i < len(break_points): + self._save(sess, saver) + ckpt_saved = True return outputs diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index e0237198b7d47eb98eeffe88d28bf9681b2722c6..3b07ef290bc38daa37472ef8919f3350851fe370 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -24,9 +24,8 @@ import zlib import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import readers -from tensorflow.core.example import example_pb2 -from tensorflow.core.example import feature_pb2 from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import constant_op @@ -280,163 +279,8 @@ def _interleave(iterators, cycle_length): num_open -= 1 -class ReadBatchFeaturesTest(test.TestCase): - - def setUp(self): - super(ReadBatchFeaturesTest, self).setUp() - self._num_files = 2 - self._num_records = 7 - self.test_filenames = self._createFiles() - - def _read_batch_features(self, - filenames, - num_epochs, - batch_size, - reader_num_threads=1, - parser_num_threads=1, - shuffle=False, - shuffle_seed=None, - drop_final_batch=False): - self.filenames = filenames - self.num_epochs = num_epochs - self.batch_size = batch_size - - return readers.make_batched_features_dataset( - file_pattern=self.filenames, - batch_size=self.batch_size, - features={ - "file": 
parsing_ops.FixedLenFeature([], dtypes.int64), - "record": parsing_ops.FixedLenFeature([], dtypes.int64), - "keywords": parsing_ops.VarLenFeature(dtypes.string) - }, - reader=core_readers.TFRecordDataset, - num_epochs=self.num_epochs, - shuffle=shuffle, - shuffle_seed=shuffle_seed, - reader_num_threads=reader_num_threads, - parser_num_threads=parser_num_threads, - drop_final_batch=drop_final_batch).make_one_shot_iterator( - ).get_next() - - def _record(self, f, r): - example = example_pb2.Example( - features=feature_pb2.Features( - feature={ - "file": - feature_pb2.Feature( - int64_list=feature_pb2.Int64List(value=[f])), - "record": - feature_pb2.Feature( - int64_list=feature_pb2.Int64List(value=[r])), - "keywords": - feature_pb2.Feature( - bytes_list=feature_pb2.BytesList( - value=self._get_keywords(f, r))) - })) - return example.SerializeToString() - - def _get_keywords(self, f, r): - num_keywords = 1 + (f + r) % 2 - keywords = [] - for index in range(num_keywords): - keywords.append(compat.as_bytes("keyword%d" % index)) - return keywords - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = python_io.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._record(i, j)) - writer.close() - return filenames - - def _run_actual_batch(self, outputs, sess): - file_op = outputs["file"] - keywords_indices_op = outputs["keywords"].indices - keywords_values_op = outputs["keywords"].values - keywords_dense_shape_op = outputs["keywords"].dense_shape - record_op = outputs["record"] - return sess.run([ - file_op, keywords_indices_op, keywords_values_op, - keywords_dense_shape_op, record_op - ]) - - def _next_actual_batch(self, sess): - return self._run_actual_batch(self.outputs, sess) - - def _next_expected_batch(self, - file_indices, - batch_size, - num_epochs, - cycle_length=1): - - def _next_record(file_indices): - for j in 
file_indices: - for i in range(self._num_records): - yield j, i - - def _next_record_interleaved(file_indices, cycle_length): - return _interleave([_next_record([i]) for i in file_indices], - cycle_length) - - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - for _ in range(num_epochs): - if cycle_length == 1: - next_records = _next_record(file_indices) - else: - next_records = _next_record_interleaved(file_indices, cycle_length) - for record in next_records: - f = record[0] - r = record[1] - file_batch.append(f) - record_batch.append(r) - keywords = self._get_keywords(f, r) - keywords_batch_values.extend(keywords) - keywords_batch_indices.extend( - [[batch_index, i] for i in range(len(keywords))]) - batch_index += 1 - keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) - if len(file_batch) == batch_size: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [batch_size, keywords_batch_max_len], record_batch - ] - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - if file_batch: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [len(file_batch), keywords_batch_max_len], record_batch - ] - - def _verify_records(self, - sess, - batch_size, - file_index=None, - num_epochs=1, - interleave_cycle_length=1): - if file_index is not None: - file_indices = [file_index] - else: - file_indices = range(self._num_files) - - for expected_batch in self._next_expected_batch( - file_indices, batch_size, num_epochs, interleave_cycle_length): - actual_batch = self._next_actual_batch(sess) - for i in range(len(expected_batch)): - self.assertAllEqual(expected_batch[i], actual_batch[i]) +class ReadBatchFeaturesTest( + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): def testRead(self): for batch_size in [1, 2]: @@ -444,33 +288,33 @@ class 
ReadBatchFeaturesTest(test.TestCase): with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 0. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 0, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 0, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 1. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[1], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 1, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 1, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from both files. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames, num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) @@ -504,18 +348,18 @@ class ReadBatchFeaturesTest(test.TestCase): # Test that shuffling with same seed produces the same result. 
with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - outputs1 = self._read_batch_features( + outputs1 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) - outputs2 = self._read_batch_features( + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) + shuffle_seed=5).make_one_shot_iterator().get_next() for _ in range(total_records // batch_size): batch1 = self._run_actual_batch(outputs1, sess) batch2 = self._run_actual_batch(outputs2, sess) @@ -525,18 +369,18 @@ class ReadBatchFeaturesTest(test.TestCase): # Test that shuffling with different seeds produces a different order. with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - outputs1 = self._read_batch_features( + outputs1 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) - outputs2 = self._read_batch_features( + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=15) + shuffle_seed=15).make_one_shot_iterator().get_next() all_equal = True for _ in range(total_records // batch_size): batch1 = self._run_actual_batch(outputs1, sess) @@ -552,13 +396,14 @@ class ReadBatchFeaturesTest(test.TestCase): for parser_num_threads in [2, 4]: with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames, num_epochs=num_epochs, batch_size=batch_size, reader_num_threads=reader_num_threads, - parser_num_threads=parser_num_threads) - self._verify_records( + 
parser_num_threads=parser_num_threads).make_one_shot_iterator( + ).get_next() + self.verify_records( sess, batch_size, num_epochs=num_epochs, @@ -571,11 +416,11 @@ class ReadBatchFeaturesTest(test.TestCase): for num_epochs in [1, 10]: with ops.Graph().as_default(): # Basic test: read from file 0. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, - drop_final_batch=True) + drop_final_batch=True).make_one_shot_iterator().get_next() for _, tensor in self.outputs.items(): if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. self.assertEqual(tensor.shape[0], batch_size) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..805a7c7b7384d53cc166a48ba243502ef8643280 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py @@ -0,0 +1,218 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class ReadBatchFeaturesTestBase(test.TestCase): + """Base class for setting up and testing `make_batched_features_dataset`.""" + + def setUp(self): + super(ReadBatchFeaturesTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self.test_filenames = self._createFiles() + + def make_batch_feature(self, + filenames, + num_epochs, + batch_size, + reader_num_threads=1, + parser_num_threads=1, + shuffle=False, + shuffle_seed=None, + drop_final_batch=False): + self.filenames = filenames + self.num_epochs = num_epochs + self.batch_size = batch_size + + return readers.make_batched_features_dataset( + file_pattern=self.filenames, + batch_size=self.batch_size, + features={ + "file": parsing_ops.FixedLenFeature([], dtypes.int64), + "record": parsing_ops.FixedLenFeature([], dtypes.int64), + "keywords": parsing_ops.VarLenFeature(dtypes.string) + }, + reader=core_readers.TFRecordDataset, + num_epochs=self.num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads, + drop_final_batch=drop_final_batch) + + def _record(self, f, r): + example = example_pb2.Example( + features=feature_pb2.Features( + feature={ + "file": + feature_pb2.Feature( +
int64_list=feature_pb2.Int64List(value=[f])), + "record": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[r])), + "keywords": + feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=self._get_keywords(f, r))) + })) + return example.SerializeToString() + + def _get_keywords(self, f, r): + num_keywords = 1 + (f + r) % 2 + keywords = [] + for index in range(num_keywords): + keywords.append(compat.as_bytes("keyword%d" % index)) + return keywords + + def _sum_keywords(self, num_files): + sum_keywords = 0 + for i in range(num_files): + for j in range(self._num_records): + sum_keywords += 1 + (i + j) % 2 + return sum_keywords + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j)) + writer.close() + return filenames + + def _run_actual_batch(self, outputs, sess): + file_op = outputs["file"] + keywords_indices_op = outputs["keywords"].indices + keywords_values_op = outputs["keywords"].values + keywords_dense_shape_op = outputs["keywords"].dense_shape + record_op = outputs["record"] + return sess.run([ + file_op, keywords_indices_op, keywords_values_op, + keywords_dense_shape_op, record_op + ]) + + def _next_actual_batch(self, sess): + return self._run_actual_batch(self.outputs, sess) + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + 
def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length=1): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) + + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for record in next_records: + f = record[0] + r = record[1] + file_batch.append(f) + record_batch.append(r) + keywords = self._get_keywords(f, r) + keywords_batch_values.extend(keywords) + keywords_batch_indices.extend( + [[batch_index, i] for i in range(len(keywords))]) + batch_index += 1 + keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) + if len(file_batch) == batch_size: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [batch_size, keywords_batch_max_len], record_batch + ] + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + if file_batch: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [len(file_batch), keywords_batch_max_len], record_batch + ] + + def verify_records(self, + sess, + batch_size, + file_index=None, + num_epochs=1, + interleave_cycle_length=1): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length): + actual_batch = self._next_actual_batch(sess) + for i in range(len(expected_batch)): + self.assertAllEqual(expected_batch[i], actual_batch[i]) diff --git 
a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index bdc003a8a5bd646e1d5c598befa2694da512d0a9..520da7d6ff3ed50352a89c8a2d4f08122eb922dd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -17,10 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import time -from absl.testing import parameterized from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.data.ops import dataset_ops diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index bcc644c0971854d948025009dc7add2fea214048..25e9ea47b82dad479f041a7be37c984f96c95e0e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -20,11 +20,13 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib class ShuffleDatasetSerializationTest( @@ -46,30 +48,104 @@ class ShuffleDatasetSerializationTest( def testShuffleCore(self): seed = 55 - range_limit = 10 - num_repeats = 5 + range_limit = 5 + num_repeats = 2 num_outputs = range_limit * num_repeats - buffer_sizes = 
[1, 3, 8, 10, 25, 50] - reshuffle_each_iteration = False + buffer_sizes = [1, 3, 5, 8, 10] # pylint: disable=cell-var-from-loop # pylint: disable=g-long-lambda - for buffer_size in buffer_sizes: - self.run_core_tests( - lambda: self._build_shuffle_dataset( + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + self.run_core_tests( + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration), + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=10, + reshuffle_each_iteration=reshuffle_each_iteration), + num_outputs) + # pylint: enable=cell-var-from-loop + # pylint: enable=g-long-lambda + + def testNonDeterministicSeeding(self): + + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( range_limit=range_limit, num_repeats=num_repeats, buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration), - lambda: self._build_shuffle_dataset( + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + # We checkpoint the initial state of the Dataset so that we can restore + # the seeds in the next run. Since the seeding is non-deterministic + # the dataset gets initialized with different seeds each time. 
+ expected = self.gen_outputs( + ds_fn, + break_points=[0], + num_outputs=num_outputs, + ckpt_saved=False, + verify_exhausted=False, + save_checkpoint_at_end=False) + actual = self.gen_outputs( + ds_fn, + break_points=self.gen_break_points(num_outputs), + num_outputs=num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.match(expected, actual) + + def testMultipleIterators(self): + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( range_limit=range_limit, num_repeats=num_repeats, buffer_size=buffer_size, - seed=10, - reshuffle_each_iteration=reshuffle_each_iteration), - num_outputs) - # pylint: enable=cell-var-from-loop - # pylint: enable=g-long-lambda + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + with ops.Graph().as_default() as g: + ds = ds_fn() + iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()] + get_next_ops = [it.get_next() for it in iterators] + saveables = [ + contrib_iterator_ops.make_saveable_from_iterator(it) + for it in iterators + ] + for saveable in saveables: + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saver = saver_lib.Saver(allow_empty=True) + with self.test_session(graph=g) as sess: + self._save(sess, saver) + expected = [sess.run(get_next_ops) for _ in range(num_outputs)] + self._restore(saver, sess) + actual = [sess.run(get_next_ops) for _ in range(num_outputs)] + self.match(expected, actual) class ShuffleAndRepeatTest( diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 
5c74ed6ae7210e8e22efb6e8fdb773397459ce1e..17b6644759e53f84b23e070a71267aa15dcffe49 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops @@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class StatsDatasetTest(test.TestCase): +class StatsDatasetTestBase(test.TestCase): def _assertSummaryHasCount(self, summary_str, tag, expected_value): summary_proto = summary_pb2.Summary() @@ -49,6 +50,9 @@ class StatsDatasetTest(test.TestCase): return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + +class StatsDatasetTest(StatsDatasetTestBase): + def testBytesProduced(self): stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).map( @@ -193,6 +197,45 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) +class FeatureStatsDatasetTest( + StatsDatasetTestBase, + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): + + def testFeaturesStats(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + batch_size = 2 + stats_aggregator = stats_ops.StatsAggregator() + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5, + drop_final_batch=True).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = 
iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(total_records // batch_size): + sess.run(next_element) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:features", total_records) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:feature-values", total_records) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:features", total_records * 3) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:feature-values", + self._sum_keywords(1) * num_epochs + 2 * total_records) + + class StatsDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 086661adb7603345be09a4c710d4fb2b170ac8f9..33b7a75046cf2acfa3d787833b907aa2b28dbdca 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -96,8 +96,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":batching", + ":gen_dataset_ops", ":interleave_ops", ":shuffle_ops", + ":stats_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", @@ -106,12 +108,12 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -142,6 +144,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python/data/ops:dataset_ops", + 
"//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", ], diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index b9393de4e90ae2597045b29070934b94e18cfcbd..052618e08c8f204613db5a20d42e078f17f12840 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib.framework import with_shape from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -29,6 +30,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import deprecation def dense_to_sparse_batch(batch_size, row_shape): @@ -101,10 +103,7 @@ class UnbatchDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.unbatch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): @@ -218,6 +217,8 @@ def filter_irregular_batches(batch_size): return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.") def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). 
@@ -250,12 +251,16 @@ def batch_and_drop_remainder(batch_size): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time + # after 6/30/2018. batched = dataset.batch(batch_size) return filter_irregular_batches(batch_size)(batched) return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.padded_batch(..., drop_remainder=True)`.") def padded_batch_and_drop_remainder(batch_size, padded_shapes, padding_values=None): @@ -284,6 +289,8 @@ def padded_batch_and_drop_remainder(batch_size, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)` + # any time after 6/30/2018. batched = dataset.padded_batch( batch_size, padded_shapes=padded_shapes, padding_values=padding_values) return filter_irregular_batches(batch_size)(batched) @@ -309,11 +316,8 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): return gen_dataset_ops.dense_to_sparse_batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._batch_size, - row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + row_shape=convert.partial_shape_to_tensor(self._row_shape), + **dataset_ops.flat_structure(self)) @property def output_classes(self): @@ -490,10 +494,7 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): batch_size=self._batch_size_t, num_parallel_calls=self._num_parallel_calls_t, drop_remainder=self._drop_remainder_t, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, 
self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 6c21e489f7c35484ebacd465e3b46d6920df5933..5f5513849cb29a18b86ba8bcee1ab6c9c60674cb 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse def ignore_errors(): @@ -64,10 +62,7 @@ class IgnoreErrorsDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.ignore_errors_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index 3a07df572748e464284f580d67e3a664e71acdfe..0f4cd8e20c5727a5bcfa1dce4dadbfa8f90bd551 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -64,10 +64,7 @@ def get_single_element(dataset): nested_ret = nest.pack_sequence_as( dataset.output_types, gen_dataset_ops.dataset_to_single_element( dataset._as_variant_tensor(), # pylint: disable=protected-access - output_types=nest.flatten(sparse.as_dense_types( - dataset.output_types, dataset.output_classes)), - 
output_shapes=nest.flatten(sparse.as_dense_shapes( - dataset.output_shapes, dataset.output_classes)))) + **dataset_ops.flat_structure(dataset))) return sparse.deserialize_sparse_tensors( nested_ret, dataset.output_types, dataset.output_shapes, dataset.output_classes) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index ea229b5b27b117984e508fa4edc6f1cf713008b4..4068a2ffa5ab877c372a6f32e3430812aa138391 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -21,12 +21,9 @@ import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -273,67 +270,27 @@ class GroupByReducerDataset(dataset_ops.Dataset): def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_key_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. 
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - # pylint: disable=protected-access - if dataset_ops._should_unpack_args(nested_args): - ret = key_func(*nested_args) - # pylint: enable=protected-access - else: - ret = key_func(nested_args) - ret = ops.convert_to_tensor(ret) - if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar(): - raise ValueError( - "`key_func` must return a single tf.int64 tensor. " - "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape())) - return ret - - self._key_func = tf_key_func - self._key_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + key_func, "tf.contrib.data.group_by_reducer()", input_dataset) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`key_func` must return a single tf.int64 tensor. " + "Got type=%s and shape=%s" + % (wrapped_func.output_types, wrapped_func.output_shapes)) + self._key_func = wrapped_func.function def _make_init_func(self, init_func): """Make wrapping Defun for init_func.""" - - @function.Defun(dtypes.int64) - def tf_init_func(key): - """A wrapper for Defun that facilitates shape inference.""" - key.set_shape([]) - ret = init_func(key) - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. 
- ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._state_classes = sparse.get_classes(ret) - self._state_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._state_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._init_func = tf_init_func - self._init_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + init_func, "tf.contrib.data.group_by_reducer()", + input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), + input_types=dtypes.int64) + self._init_func = wrapped_func.function + self._state_classes = wrapped_func.output_classes + self._state_shapes = wrapped_func.output_shapes + self._state_types = wrapped_func.output_types def _make_reduce_func(self, reduce_func, input_dataset): """Make wrapping Defun for reduce_func.""" @@ -343,83 +300,47 @@ class GroupByReducerDataset(dataset_ops.Dataset): need_to_rerun = True while need_to_rerun: - # Create a list in which `tf_reduce_func` will store the new shapes. 
- flat_new_state_shapes = [] - - @function.Defun(*(nest.flatten( - sparse.as_dense_types( - self._state_types, self._state_classes)) + nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes)))) - def tf_reduce_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes)) - + nest.flatten( - sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes))): - arg.set_shape(shape) - - pivot = len(nest.flatten(self._state_shapes)) - nested_state_args = nest.pack_sequence_as(self._state_types, - args[:pivot]) - nested_state_args = sparse.deserialize_sparse_tensors( - nested_state_args, self._state_types, self._state_shapes, - self._state_classes) - nested_input_args = nest.pack_sequence_as(input_dataset.output_types, - args[pivot:]) - nested_input_args = sparse.deserialize_sparse_tensors( - nested_input_args, input_dataset.output_types, - input_dataset.output_shapes, input_dataset.output_classes) - - ret = reduce_func(nested_state_args, nested_input_args) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - # Extract shape information from the returned values. - flat_new_state = nest.flatten(ret) - flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state]) - - # Extract and validate type information from the returned values. - for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)): - if t.dtype != dtype: - raise TypeError( - "The element types for the new state must match the initial " - "state. Expected %s; got %s." 
% - (self._state_types, - nest.pack_sequence_as(self._state_types, - [t.dtype for t in flat_new_state]))) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, - [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - # Use the private method that will execute `tf_reduce_func` but delay - # adding it to the graph in case we need to rerun the function. - tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access - + wrapped_func = dataset_ops.StructuredFunctionWrapper( + reduce_func, "tf.contrib.data.group_by_reducer()", + input_classes=(self._state_classes, input_dataset.output_classes), + input_shapes=(self._state_shapes, input_dataset.output_shapes), + input_types=(self._state_types, input_dataset.output_types), + add_to_graph=False) + + # Extract and validate class information from the returned values. + for new_state_class, state_class in zip( + nest.flatten(wrapped_func.output_classes), + nest.flatten(self._state_classes)): + if not issubclass(new_state_class, state_class): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, wrapped_func.output_classes)) + + # Extract and validate type information from the returned values. + for new_state_type, state_type in zip( + nest.flatten(wrapped_func.output_types), + nest.flatten(self._state_types)): + if new_state_type != state_type: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, wrapped_func.output_types)) + + # Extract shape information from the returned values. 
flat_state_shapes = nest.flatten(self._state_shapes) + flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes) weakened_state_shapes = [ - old.most_specific_compatible_shape(new) - for old, new in zip(flat_state_shapes, flat_new_state_shapes) + original.most_specific_compatible_shape(new) + for original, new in zip(flat_state_shapes, flat_new_state_shapes) ] need_to_rerun = False - for old_shape, weakened_shape in zip(flat_state_shapes, - weakened_state_shapes): - if old_shape.ndims is not None and ( + for original_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if original_shape.ndims is not None and ( weakened_shape.ndims is None or - old_shape.as_list() != weakened_shape.as_list()): + original_shape.as_list() != weakened_shape.as_list()): need_to_rerun = True break @@ -427,50 +348,19 @@ class GroupByReducerDataset(dataset_ops.Dataset): self._state_shapes = nest.pack_sequence_as(self._state_shapes, weakened_state_shapes) - self._reduce_func = tf_reduce_func + self._reduce_func = wrapped_func.function self._reduce_func.add_to_graph(ops.get_default_graph()) def _make_finalize_func(self, finalize_func): """Make wrapping Defun for finalize_func.""" - - @function.Defun(*(nest.flatten( - sparse.as_dense_types(self._state_types, self._state_classes)))) - def tf_finalize_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes))): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(self._state_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, self._state_types, self._state_shapes, - self._state_classes) - - ret = finalize_func(nested_args) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. 
- ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._output_classes = sparse.get_classes(ret) - self._output_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._output_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._finalize_func = tf_finalize_func - self._finalize_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + finalize_func, "tf.contrib.data.group_by_reducer()", + input_classes=self._state_classes, input_shapes=self._state_shapes, + input_types=self._state_types) + self._finalize_func = wrapped_func.function + self._output_classes = wrapped_func.output_classes + self._output_shapes = wrapped_func.output_shapes + self._output_types = wrapped_func.output_types @property def output_classes(self): @@ -495,10 +385,7 @@ class GroupByReducerDataset(dataset_ops.Dataset): init_func=self._init_func, reduce_func=self._reduce_func, finalize_func=self._finalize_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) class GroupByWindowDataset(dataset_ops.Dataset): @@ -516,64 +403,39 @@ class GroupByWindowDataset(dataset_ops.Dataset): def _make_window_size_func(self, window_size_func): """Make wrapping Defun for window_size_func.""" - - @function.Defun(dtypes.int64) - def tf_window_size_func(key): - key.set_shape([]) - window_size = ops.convert_to_tensor( - window_size_func(key), dtype=dtypes.int64) - if window_size.dtype != dtypes.int64: - raise ValueError( - 
"`window_size_func` must return a single tf.int64 tensor.") - return window_size - - self._window_size_func = tf_window_size_func - self._window_size_func.add_to_graph(ops.get_default_graph()) + def window_size_func_wrapper(key): + return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + window_size_func_wrapper, "tf.contrib.data.group_by_window()", + input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), + input_types=dtypes.int64) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`window_size_func` must return a single tf.int64 scalar tensor.") + self._window_size_func = wrapped_func.function def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_key_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. 
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - # pylint: disable=protected-access - if dataset_ops._should_unpack_args(nested_args): - ret = key_func(*nested_args) - # pylint: enable=protected-access - else: - ret = key_func(nested_args) - ret = ops.convert_to_tensor(ret, dtype=dtypes.int64) - if ret.dtype != dtypes.int64: - raise ValueError("`key_func` must return a single tf.int64 tensor.") - return ret - - self._key_func = tf_key_func - self._key_func.add_to_graph(ops.get_default_graph()) + def key_func_wrapper(*args): + return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`key_func` must return a single tf.int64 scalar tensor.") + self._key_func = wrapped_func.function def _make_reduce_func(self, reduce_func, input_dataset): """Make wrapping Defun for reduce_func.""" - - @function.Defun(dtypes.int64, dtypes.variant) - def tf_reduce_func(key, window_dataset_variant): - """A wrapper for Defun that facilitates shape inference.""" - key.set_shape([]) + def reduce_func_wrapper(key, window_dataset_variant): + """Wrapper that converts between tf.variant and Dataset objects.""" window_dataset = _VariantDataset( window_dataset_variant, input_dataset.output_types, input_dataset.output_shapes, input_dataset.output_classes) - if not isinstance(window_dataset, dataset_ops.Dataset): - raise TypeError("`window_dataset` must 
return a `Dataset` object.") output_dataset = reduce_func(key, window_dataset) if not isinstance(output_dataset, dataset_ops.Dataset): raise TypeError("`reduce_func` must return a `Dataset` object.") @@ -582,8 +444,12 @@ class GroupByWindowDataset(dataset_ops.Dataset): self._output_shapes = output_dataset.output_shapes return output_dataset._as_variant_tensor() # pylint: disable=protected-access - self._reduce_func = tf_reduce_func - self._reduce_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + reduce_func_wrapper, "tf.contrib.data.reduce_by_window()", + input_classes=(ops.Tensor, ops.Tensor), + input_shapes=(tensor_shape.scalar(), tensor_shape.scalar()), + input_types=(dtypes.int64, dtypes.variant)) + self._reduce_func = wrapped_func.function @property def output_classes(self): @@ -606,10 +472,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): key_func=self._key_func, reduce_func=self._reduce_func, window_size_func=self._window_size_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) class Reducer(object): diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index be66fbac50753c8f54b62dd615ee60804f4cf20d..70153ac575758f16beff373941dfefb32bd342cf 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -24,7 +24,6 @@ from tensorflow.contrib.data.python.ops import random_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import 
tensor_shape @@ -171,10 +170,7 @@ class DirectedInterleaveDataset(dataset_ops.Dataset): return gen_dataset_ops.directed_interleave_dataset( self._selector_input._as_variant_tensor(), [data_input._as_variant_tensor() for data_input in self._data_inputs], - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index cad41bce2961f29a7591fe3d382d1ab35a6b38b4..2ca3805d6609a82aa733da36d84c7fb58921d764 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -19,8 +19,6 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops @@ -62,10 +60,7 @@ class OptimizeDataset(dataset_ops.Dataset): return gen_dataset_ops.optimize_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._optimizations, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py index 28ef5e50f39dd7d1b6f124e58e068fc968ddd6dc..e670c4c8354f4067eb21c9b1fce708147c162967 100644 --- 
a/tensorflow/contrib/data/python/ops/random_ops.py +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -18,9 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -39,10 +37,7 @@ class RandomDataset(dataset_ops.Dataset): return gen_dataset_ops.random_dataset( seed=self._seed, seed2=self._seed2, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index f938153f5f8c8becc5877a667117fd6facd3e428..83095c7ba1c6465d18490e5197f71bf7f1fe2497 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -26,6 +26,7 @@ from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import convert @@ -754,6 +755,8 @@ def make_batched_features_dataset(file_pattern, dataset = _maybe_shuffle_and_repeat( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + dataset = dataset.apply(stats_ops.feature_stats("record_stats")) + if 
drop_final_batch: dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) else: diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index bad6edd5147d832228c412919f1e6e782aafc40f..182a5c6ff36fcda8c9e2c522cce07bed0c2daec9 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -291,4 +291,4 @@ def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): # TODO(joelshor): Simplify fraction, if possible. a_i = (ratio_l - m) / (max_ratio - m) - return a_i, m \ No newline at end of file + return a_i, m diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index e911ad0fa0541f2d8b991d66182dd002c2ecaab0..ea9dcfe68fa2630d915323fa295031af7d48cdfb 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -22,7 +22,6 @@ import collections from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse -from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import gen_dataset_ops @@ -67,102 +66,45 @@ class _ScanDataset(dataset_ops.Dataset): need_to_rerun = True while need_to_rerun: - # Create a list in which `tf_scan_func` will store the new shapes. - flat_new_state_shapes = [] - - @function.Defun(*(nest.flatten( - sparse.as_dense_types( - self._state_types, self._state_classes)) + nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes)))) - def tf_scan_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the state and input_dataset. 
- for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes)) - + nest.flatten( - sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes))): - arg.set_shape(shape) - - pivot = len(nest.flatten(self._state_shapes)) - print(self._state_classes) - nested_state_args = nest.pack_sequence_as(self._state_types, - args[:pivot]) - nested_state_args = sparse.deserialize_sparse_tensors( - nested_state_args, self._state_types, self._state_shapes, - self._state_classes) - print(input_dataset.output_classes) - nested_input_args = nest.pack_sequence_as(input_dataset.output_types, - args[pivot:]) - nested_input_args = sparse.deserialize_sparse_tensors( - nested_input_args, input_dataset.output_types, - input_dataset.output_shapes, input_dataset.output_classes) - - ret = scan_func(nested_state_args, nested_input_args) - if not isinstance(ret, collections.Sequence) or len(ret) != 2: - raise TypeError("The scan function must return a pair comprising the " - "new state and the output value.") - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - new_state, output_value = ret - - # Extract and validate class information from the returned values. - for t, clazz in zip( - nest.flatten(new_state), nest.flatten(self._state_classes)): - if not isinstance(t, clazz): - raise TypeError( - "The element classes for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_classes, - nest.pack_sequence_as( - self._state_types, - [type(t) for t in nest.flatten(new_state)]))) - self._output_classes = sparse.get_classes(output_value) - - # Extract shape information from the returned values. 
- flat_new_state_shapes.extend( - [t.get_shape() for t in nest.flatten(new_state)]) - self._output_shapes = nest.pack_sequence_as( - output_value, [t.get_shape() for t in nest.flatten(output_value)]) - - # Extract and validate type information from the returned values. - for t, dtype in zip( - nest.flatten(new_state), nest.flatten(self._state_types)): - if t.dtype != dtype: - raise TypeError( - "The element types for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_types, - nest.pack_sequence_as( - self._state_types, - [t.dtype for t in nest.flatten(new_state)]))) - self._output_types = nest.pack_sequence_as( - output_value, [t.dtype for t in nest.flatten(output_value)]) - - # Serialize any sparse tensors. - new_state = nest.pack_sequence_as(new_state, [ - t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state)) - ]) - output_value = nest.pack_sequence_as(output_value, [ - t for t in nest.flatten( - sparse.serialize_sparse_tensors(output_value)) - ]) - return nest.flatten(new_state) + nest.flatten(output_value) - - # Use the private method that will execute `tf_scan_func` but delay - # adding it to the graph in case we need to rerun the function. - tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access + wrapped_func = dataset_ops.StructuredFunctionWrapper( + scan_func, "tf.contrib.data.scan()", + input_classes=(self._state_classes, input_dataset.output_classes), + input_shapes=(self._state_shapes, input_dataset.output_shapes), + input_types=(self._state_types, input_dataset.output_types), + add_to_graph=False) + if not ( + isinstance(wrapped_func.output_types, collections.Sequence) and + len(wrapped_func.output_types) == 2): + raise TypeError("The scan function must return a pair comprising the " + "new state and the output value.") + + new_state_classes, self._output_classes = wrapped_func.output_classes + + # Extract and validate class information from the returned values. 
+ for new_state_class, state_class in zip( + nest.flatten(new_state_classes), + nest.flatten(self._state_classes)): + if not issubclass(new_state_class, state_class): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, new_state_classes)) + + # Extract and validate type information from the returned values. + new_state_types, self._output_types = wrapped_func.output_types + for new_state_type, state_type in zip( + nest.flatten(new_state_types), nest.flatten(self._state_types)): + if new_state_type != state_type: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, new_state_types)) + + # Extract shape information from the returned values. + new_state_shapes, self._output_shapes = wrapped_func.output_shapes flat_state_shapes = nest.flatten(self._state_shapes) + flat_new_state_shapes = nest.flatten(new_state_shapes) weakened_state_shapes = [ original.most_specific_compatible_shape(new) for original, new in zip(flat_state_shapes, flat_new_state_shapes) @@ -178,12 +120,10 @@ class _ScanDataset(dataset_ops.Dataset): break if need_to_rerun: - # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun - # `tf_scan_func`. 
self._state_shapes = nest.pack_sequence_as(self._state_shapes, weakened_state_shapes) - self._scan_func = tf_scan_func + self._scan_func = wrapped_func.function self._scan_func.add_to_graph(ops.get_default_graph()) def _as_variant_tensor(self): @@ -193,10 +133,7 @@ class _ScanDataset(dataset_ops.Dataset): nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)), self._scan_func.captured_inputs, f=self._scan_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index f35795abd38000b13cec0f08596e2ff66e86286c..d7f8a73fe3d67bb83e44e962832ce34c116aef66 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -18,9 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed -from tensorflow.python.data.util import sparse from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -56,10 +54,7 @@ class _ShuffleAndRepeatDataset(dataset_ops.Dataset): count=self._count, seed=self._seed, seed2=self._seed2, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 
19cc3cb89fc5c494f79ce1d25ed57c92099c8bd2..f935beb1a9e85d4901857e7781a5ed8473838fa5 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -19,7 +19,6 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -43,10 +42,7 @@ class _SlideDataset(dataset_ops.Dataset): self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access window_size=self._window_size, stride=self._stride, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 3cbaab5affd7397213b0fbb6b0682db92b99d591..3c82a03df1745d855b2d3f918f7bbde113600556 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -18,8 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops @@ -97,10 +95,7 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset): return gen_dataset_ops.set_stats_aggregator_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._stats_aggregator._resource, # pylint: disable=protected-access - output_types=nest.flatten( - 
sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): @@ -176,6 +171,27 @@ def latency_stats(tag): return _apply_fn +def feature_stats(tag): + """Records the features stats from `Example` records of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with the output + dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will be + associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.feature_stats_dataset, tag) + + return _apply_fn + + class _StatsDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and also records statistics.""" @@ -189,10 +205,7 @@ class _StatsDataset(dataset_ops.Dataset): return self._op_function( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._tag, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py index 56f67e1766bbaff680bdff6b939df0c3ba68c679..bb49604d4de90d726418684124608438aa33e6cf 100644 --- a/tensorflow/contrib/data/python/ops/threadpool.py +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -22,8 +22,6 @@ import threading from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from 
tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.eager import context from tensorflow.python.ops import resource_variable_ops @@ -69,10 +67,7 @@ class _ThreadPoolDataset(dataset_ops.Dataset): return gen_dataset_ops.thread_pool_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._thread_pool._resource, # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index 765ef3f9b6d42c9d7af3ce4916731d37d65c9260..4ce6ddede8350735636fd152fdc9df0319265990 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -65,10 +63,7 @@ class UniqueDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.unique_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 
9624abd1997b36c4424f525366c658fe24b25f3a..9dfb8552f1b0f058b44f8ed09c2ed681367293d5 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -77,6 +77,7 @@ py_library( "//tensorflow/python:device_util", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -312,7 +313,6 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_pip", - "noguitar", # TODO(b/109653107): test is flaky. ], ) @@ -591,3 +591,22 @@ cuda_py_test( "notsan", ], ) + +cuda_py_test( + name = "metrics_v1_test", + srcs = ["metrics_v1_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6bf143098c1bba64d47efce1bfface7682683d --- /dev/null +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -0,0 +1,438 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for V1 metrics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import variables + + +def _labeled_dataset_fn(): + # First four batches of x: labels, predictions -> (labels == predictions) + # 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False + # 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False + # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False + # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True + return dataset_ops.Dataset.range(1000).map( + lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4) + + +def _boolean_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # T, T -> TP; F, T -> FP; T, F -> FN + # F, F -> TN; T, T -> TP; F, T -> FP + # T, F -> FN; F, F -> TN; T, T -> TP + # F, T -> FP; T, F -> FN; F, F -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [True, True, False, False]}).repeat().batch(3) + + +def _threshold_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN + # False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP + # True, .25 -> FN; False, 0.0 -> TN; True, 
1.0 -> TP + # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3) + + +def _regression_dataset_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [1., .5, 1., 0.], + "predictions": [1., .75, .25, 0.]}).repeat() + + +def all_combinations(): + return combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + mode=["graph"]) + + +# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k, +# metrics.precision_at_k +class MetricsV1Test(test.TestCase, parameterized.TestCase): + + def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): + with ops.Graph().as_default(), distribution.scope(): + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + value, update = distribution.call_for_each_tower( + metric_fn, iterator.get_next()) + update = distribution.group(update) + self.evaluate(variables.local_variables_initializer()) + # TODO(josh11b): Once we switch to using a global batch size for input, + # replace "distribution.num_towers" with "1". + batches_per_update = distribution.num_towers + + # Update variables using the first `num_towers` batches. + self.evaluate(update) + self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value), + 0.001, msg="After first update") + + # Update variables using the second `num_towers` batches. 
+ self.evaluate(update) + self.assertAllClose(expected_fn(2 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After second update") + + if batches_per_update == 1: # Consume 4 input batches + self.evaluate(update) + self.assertAllClose(expected_fn(3 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After third update") + self.evaluate(update) + self.assertAllClose(expected_fn(4 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After fourth update") + + @combinations.generate(all_combinations()) + def testMean(self, distribution): + def _dataset_fn(): + return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4) + + def _expected_fn(num_batches): + # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc. + return num_batches * 2 - 0.5 + + self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn) + + @combinations.generate(all_combinations()) + def testAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.accuracy(labels, predictions) + + def _expected_fn(num_batches): + return [3./4, 3./8, 3./12, 4./16][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanPerClassAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_per_class_accuracy( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1., 1., 1., 0., 0.]), + mean([0.5, 0.5, 0.5, 0., 0.]), + mean([1./3, 1./3, 0.5, 0., 0.]), + mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanIOU(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + 
return metrics.mean_iou( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch + mean([1./4, 1./4, 1./3, 0., 0.]), + mean([1./6, 1./6, 1./5, 0., 0.]), + mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanTensor(self, distribution): + def _dataset_fn(): + dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) + # Want to produce a fixed, known shape, so drop remainder when batching. + dataset = dataset.apply(batching.batch_and_drop_remainder(4)) + return dataset + + def _expected_fn(num_batches): + # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2 + # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1 + # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches + # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1 + first = 2. * num_batches - 2. + return [first, first + 1., first + 2., first + 3.] 
+ + self._test_metric( + distribution, _dataset_fn, metrics.mean_tensor, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCROC(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.5, 7./9, 0.8, 0.75][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCPR(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="PR", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + 
return metrics.true_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[0.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 3., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + 
predictions = x["predictions"] + return metrics.true_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [3.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecision(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 0.5, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecisionAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecall(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecallAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanSquaredError(self, distribution): + def _metric_fn(x): + 
labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1./32, 0.208333, 0.15625][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRootMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.root_mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 0.176777, 0.456435, 0.395285][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSensitivityAtSpecificity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.sensitivity_at_specificity(labels, predictions, 0.8) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSpecificityAtSensitivity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.specificity_at_sensitivity(labels, predictions, 0.95) + + def _expected_fn(num_batches): + return [0., 1./3, 0.5, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index cef0a2907b85d230606eb530a0e94549b6b95e53..900aa10e93e8881aa236bac8a2873d5c5531c6f6 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -31,6 +31,7 @@ from 
tensorflow.python.eager import tape from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import coordinator from tensorflow.python.training import device_util @@ -343,6 +344,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): **values.select_device_mirrored(d, kwargs)) return values.regroup(updates, values.Mirrored) + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + if isinstance(tower_local_var, values.TowerLocalVariable): + return math_ops.add_n(self.unwrap(tower_local_var)) + assert isinstance(tower_local_var, values.Mirrored) + return array_ops.identity(tower_local_var.get()) + def _fetch(self, val, destination, fn): """Return a copy of `val` or `fn(val)` on `destination`.""" if isinstance(val, values.TowerLocalVariable): diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 09b6d4a515ab46879520f304cd5ef60469512380..7f4bab9d93814eb70a2a1586fc291a16b2766b90 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -102,6 +102,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.device(self._device), distribute_lib.UpdateContext(self._device): return fn(*args, **kwargs) + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + return array_ops.identity(tower_local_var) + def _fetch(self, val, destination, fn): """Return a copy of `val` or `fn(val)` on `destination`.""" with ops.device(self._device): diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 
9572ade8e497fa13a7ca0746399d3e0237ee79fd..aca544b7e7e3c6f706377de9846881bea19b92d0 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -238,17 +238,6 @@ class DistributedVariable(DistributedDelegate): pass -# Register a conversion function which reads the value of the variable, -# allowing instances of the class to be used as tensors. -def _tensor_conversion(var, dtype=None, name=None, as_ref=False): - # Try to avoid assignments to and other mutations of MirroredVariable - # state except through a DistributionStrategy.update() call. - assert not as_ref - return ops.internal_convert_to_tensor( - var.get(), dtype=dtype, name=name, as_ref=as_ref) - - -ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion) ops.register_dense_tensor_like_type(DistributedVariable) @@ -342,6 +331,20 @@ class MirroredVariable(DistributedVariable, Mirrored, return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): + # Try to avoid assignments to and other mutations of MirroredVariable + # state except through a DistributionStrategy.update() call. + assert not as_ref + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(MirroredVariable, + _tensor_conversion_mirrored) + + class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): """Class for defining how to restore a TowerLocalVariable.""" @@ -431,6 +434,17 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} +# Register a conversion function for TowerLocalVariable which allows as_ref to +# be true. 
+def _tensor_conversion_tower_local(var, dtype=None, name=None, as_ref=False): + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(TowerLocalVariable, + _tensor_conversion_tower_local) + + def _devices_match(d1, d2): return device_util.canonicalize(d1) == device_util.canonicalize(d2) diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 1c95758d96aba47e9581dde6411763e98b99a968..b0bd92c7b054b52b071e5d7601bdc48117464822 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -966,6 +966,18 @@ class TowerLocalVariableTest(test.TestCase): save_path = self._save_normal() self._restore_tower_local_sum(save_path) + def testTensorConversion(self): + with context.graph_mode(): + _, tower_local = _make_tower_local("sum") + converted = ops.internal_convert_to_tensor(tower_local, as_ref=False) + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, tower_local.dtype) + + converted = ops.internal_convert_to_tensor(tower_local, as_ref=True) + # Resource variables are converted to tensors as well when as_ref is True. + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, tower_local.dtype) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 23d9dbcd91a25e7cbb5d6cfea5d63ba8412f4255..ad00d1734dd14ed846522a33d888a5387cb25cc6 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -16,6 +16,13 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "bijectors_py", srcs = glob(["python/ops/bijectors/*.py"]), + deprecation = ("TensorFlow Distributions has migrated to " + + "TensorFlow Probability " + + "(https://github.com/tensorflow/probability).
" + + "Deprecated copies remaining in tf.contrib.distributions " + + "are unmaintained, unsupported, and will be removed by " + + "late 2018. You should update all usage of " + + "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/linalg:linalg_py", @@ -42,6 +49,13 @@ py_library( py_library( name = "distributions_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + deprecation = ("TensorFlow Distributions has migrated to " + + "TensorFlow Probability " + + "(https://github.com/tensorflow/probability). " + + "Deprecated copies remaining in tf.contrib.distributions " + + "are unmaintained, unsupported, and will be removed by " + + "late 2018. You should update all usage of " + + "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ ":bijectors_py", @@ -940,6 +954,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "fill_triangular_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/fill_triangular_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "gumbel_test", size = "small", @@ -1118,6 +1151,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "scale_tril_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/scale_tril_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + 
"//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "sigmoid_test", size = "small", @@ -1235,6 +1287,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "transform_diagonal_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/transform_diagonal_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "weibull_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 802538ba97578ce6cfe7e3555963ecd2fd014a66..5cec93c4df2e970f203253be6342bb292f296eb0 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== """Classes representing statistical distributions and ops for working with them. - -See the @{$python/contrib.distributions} guide. """ from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py new file mode 100644 index 0000000000000000000000000000000000000000..caeaf2a0c6e4fff28c0edd82cb09ca0bcee85fc3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class FillTriangularBijectorTest(test.TestCase): + """Tests the correctness of the FillTriangular bijector.""" + + @test_util.run_in_graph_and_eager_modes() + def testBijector(self): + x = np.float32(np.array([1., 2., 3.])) + y = np.float32(np.array([[3., 0.], + [2., 1.]])) + + b = bijectors.FillTriangular() + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + self.assertAllClose(fldj, 0.) + + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(ildj, 0.) 
+ + @test_util.run_in_graph_and_eager_modes() + def testShape(self): + x_shape = tensor_shape.TensorShape([5, 4, 6]) + y_shape = tensor_shape.TensorShape([5, 4, 3, 3]) + + b = bijectors.FillTriangular(validate_args=True) + + x = array_ops.ones(shape=x_shape, dtype=dtypes.float32) + y_ = b.forward(x) + self.assertAllEqual(y_.shape.as_list(), y_shape.as_list()) + x_ = b.inverse(y_) + self.assertAllEqual(x_.shape.as_list(), x_shape.as_list()) + + y_shape_ = b.forward_event_shape(x_shape) + self.assertAllEqual(y_shape_.as_list(), y_shape.as_list()) + x_shape_ = b.inverse_event_shape(y_shape) + self.assertAllEqual(x_shape_.as_list(), x_shape.as_list()) + + y_shape_tensor = self.evaluate( + b.forward_event_shape_tensor(x_shape.as_list())) + self.assertAllEqual(y_shape_tensor, y_shape.as_list()) + x_shape_tensor = self.evaluate( + b.inverse_event_shape_tensor(y_shape.as_list())) + self.assertAllEqual(x_shape_tensor, x_shape.as_list()) + + @test_util.run_in_graph_and_eager_modes() + def testShapeError(self): + + b = bijectors.FillTriangular(validate_args=True) + + x_shape_bad = tensor_shape.TensorShape([5, 4, 7]) + with self.assertRaisesRegexp(ValueError, "is not a triangular number"): + b.forward_event_shape(x_shape_bad) + with self.assertRaisesOpError("is not a triangular number"): + self.evaluate(b.forward_event_shape_tensor(x_shape_bad.as_list())) + + y_shape_bad = tensor_shape.TensorShape([5, 4, 3, 2]) + with self.assertRaisesRegexp(ValueError, "Matrix must be square"): + b.inverse_event_shape(y_shape_bad) + with self.assertRaisesOpError("Matrix must be square"): + self.evaluate(b.inverse_event_shape_tensor(y_shape_bad.as_list())) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py new file mode 100644 index 0000000000000000000000000000000000000000..566a7b3dff9b5d97a1cb143e0b32fc15984c3a02 --- 
/dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py @@ -0,0 +1,69 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ScaleTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class ScaleTriLBijectorTest(test.TestCase): + """Tests the correctness of the ScaleTriL bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testComputesCorrectValues(self): + shift = 1.61803398875 + x = np.float32(np.array([-1, .5, 2])) + y = np.float32(np.array([[np.exp(2) + shift, 0.], + [.5, np.exp(-1) + shift]])) + + b = bijectors.ScaleTriL(diag_bijector=bijectors.Exp(), + diag_shift=shift) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + @test_util.run_in_graph_and_eager_modes() + def testInvertible(self): + + # Generate random inputs from an unconstrained space, with + # event size 6 to specify 3x3 triangular matrices. 
+ + batch_shape = [2, 1] + x = np.float32(self._rng.randn(*(batch_shape + [6]))) # Seeded in setUp. + b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(), + diag_shift=3.14159) + y = self.evaluate(b.forward(x)) + self.assertAllEqual(y.shape, batch_shape + [3, 3]) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(fldj, -ildj) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 45760a29ee42835da69ef63803ccec7ce82a5a8f..795f1993ba5c31bf5a26333f31f1bc73125bff07 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -151,16 +151,24 @@ class SinhArcsinhBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval(), rtol=1e-4, atol=0.) self.assertAllClose(x, bijector.inverse(y).eval(), rtol=1e-4, atol=0.) - # Do the numpy calculation in float128 to avoid inf/nan. - y_float128 = np.float128(y) - self.assertAllClose( - np.log(np.cosh( - np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( - y_float128**2 + 1)) - - np.log(tailweight), - bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), - rtol=1e-4, - atol=0.) + # On IBM PPC systems, longdouble (np.float128) is same as double except that it can have more precision. + # Type double being of 8 bytes, can't hold square of max of float64 (which is also 8 bytes) and + # below test fails due to overflow error giving inf. So this check avoids that error by skipping square + # calculation and corresponding assert.
+ + if np.amax(y) <= np.sqrt(np.finfo(np.float128).max) and \ + np.fabs(np.amin(y)) <= np.sqrt(np.fabs(np.finfo(np.float128).min)): + + # Do the numpy calculation in float128 to avoid inf/nan. + y_float128 = np.float128(y) + self.assertAllClose( + np.log(np.cosh( + np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( + y_float128**2 + 1)) - + np.log(tailweight), + bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + rtol=1e-4, + atol=0.) self.assertAllClose( -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6428a68702274fae384ae3de6d03f7ca126e2346 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for TransformDiagonal bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class TransformDiagonalBijectorTest(test.TestCase): + """Tests correctness of the TransformDiagonal bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + @test_util.run_in_graph_and_eager_modes() + def testBijector(self): + x = np.float32(self._rng.randn(3, 4, 4)) # Seeded in setUp. + + y = x.copy() + for i in range(x.shape[0]): + np.fill_diagonal(y[i, :, :], np.exp(np.diag(x[i, :, :]))) + + exp = bijectors.Exp() + b = bijectors.TransformDiagonal(diag_bijector=exp) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=2)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllEqual( + fldj, + self.evaluate(exp.forward_log_det_jacobian( + np.array([np.diag(x_mat) for x_mat in x]), + event_ndims=1))) + self.assertAllEqual( + ildj, + self.evaluate(exp.inverse_log_det_jacobian( + np.array([np.diag(y_mat) for y_mat in y]), + event_ndims=1))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index 31d24aa9ea09007b8db40e4869371b1f62639ac7..bbbec2103aefd3f38a9b734bcd3f2e15fc8bb683 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -29,7 +29,9 @@ from
tensorflow.contrib.distributions.python.ops import mvn_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.linalg import linear_operator_diag @@ -540,5 +542,51 @@ class PadDynamicTest(_PadTest, test.TestCase): return False +class TestMoveDimension(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_move_dimension_static_shape(self): + + x = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1]) + + @test_util.run_in_graph_and_eager_modes() + def test_move_dimension_dynamic_shape(self): + + x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + x = array_ops.placeholder_with_default(input=x_, shape=None) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + 
self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + x_perm = distribution_util.move_dimension(x, -1, 2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index 11ca90c4833d84b092f0b43a8f5404e3a11450cd..bb9b8043b2233b2109f51b5dde188d088fdb0d39 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Autoregressive(distribution_lib.Distribution): @@ -107,6 +108,14 @@ class Autoregressive(distribution_lib.Distribution): https://arxiv.org/abs/1606.05328 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution_fn, sample0=None, diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index 4714caad69ee4341d259f6677decdd5842931834..519077bc9ab1063a1135486cfae34656f3f68157 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -71,6 +72,14 @@ class BatchReshape(distribution_lib.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution, batch_shape, @@ -352,6 +361,14 @@ class BatchReshape(distribution_lib.Distribution): return runtime_assertions +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def calculate_reshape(original_shape, new_shape, validate=False, name=None): """Calculates the reshaped dimensions (replacing up to one -1 in reshape).""" batch_shape_static = tensor_util.constant_value_as_shape(new_shape) @@ -384,6 +401,14 @@ def calculate_reshape(original_shape, new_shape, validate=False, name=None): return expanded_new_shape, batch_shape_static, validations +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def validate_init_args_statically(distribution, batch_shape): """Helper to __init__ which makes or raises assertions.""" if batch_shape.shape.ndims is not None: diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 4965381ef33e14cef0e0339341d50c943d412d8f..e141f8b5c6423bd6cce4d09da6f49d55b3e25a24 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -24,6 +24,7 @@ @@CholeskyOuterProduct @@ConditionalBijector @@Exp +@@FillTriangular @@Gumbel @@Identity @@Inline @@ -36,12 +37,14 @@ @@PowerTransform @@RealNVP @@Reshape +@@ScaleTriL @@Sigmoid @@SinhArcsinh @@SoftmaxCentered @@Softplus @@Softsign @@Square +@@TransformDiagonal @@Weibull @@masked_autoregressive_default_template @@ -64,6 +67,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from 
tensorflow.contrib.distributions.python.ops.bijectors.exp import * +from tensorflow.contrib.distributions.python.ops.bijectors.fill_triangular import * from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * @@ -75,12 +79,14 @@ from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * +from tensorflow.contrib.distributions.python.ops.bijectors.scale_tril import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * from tensorflow.contrib.distributions.python.ops.bijectors.softsign import * from tensorflow.contrib.distributions.python.ops.bijectors.square import * +from tensorflow.contrib.distributions.python.ops.bijectors.transform_diagonal import * from tensorflow.python.ops.distributions.bijector import * from tensorflow.python.ops.distributions.identity_bijector import Identity diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index c9e31d7712f09f6c4b4cc6ae51a34c42a19c291d..4d6a46e7358933fdf512f49eae2673f35953c90a 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -23,6 +23,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import 
control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "AbsoluteValue", @@ -70,6 +71,14 @@ class AbsoluteValue(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="absolute_value"): """Instantiates the `AbsoluteValue` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index b4c2939eb914d50475ba6b1c1e979a804090f641..25f29452c3949600b8a4153a8585dd7269bd3b2b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -36,6 +37,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _as_tensor(x, name): """Convenience to convert to `Tensor` or leave as `None`.""" return None if x is None else ops.convert_to_tensor(x, name=name) @@ -97,6 +106,14 @@ class Affine(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale_identity_multiplier=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index 59f9742d576a7804f401d3a47ba31ae61d6c6e54..91301f15ad87e133777371b346864ecf7b964f27 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.util import deprecation __all__ = [ @@ -88,6 +89,14 @@ class AffineLinearOperator(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py index cd792e2c8cf48602daf9fb5eb56b8c34bac050c7..460d906231bd30f8cec4fe21d42afe7b2a05805e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -52,6 +53,14 @@ class AffineScalar(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py index 224cec8a63dba53a528490117efac890312fe8d5..f19f147dd645b4f805f1905899b44293284d4225 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -34,6 +35,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _undo_batch_normalization(x, mean, variance, @@ -128,6 +137,14 @@ class BatchNormalization(bijector.Bijector): Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, batchnorm_layer=None, training=True, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 16f959560ce0f171035b3ef0bd80b16dae1cc654..910774ea5bb4106a948567144c46c6db23a2c6e0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -31,10 +32,26 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _use_static_shape(input_tensor, ndims): return input_tensor.shape.is_fully_defined() and isinstance(ndims, int) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _compute_min_event_ndims(bijector_list, compute_forward=True): """Computes the min_event_ndims associated with the give list of bijectors. @@ -142,6 +159,14 @@ class Chain(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, bijectors=None, validate_args=False, name=None): """Instantiates `Chain` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 268c8d03426d435dc38412ac1bd05c674bd05d2b..8267ee7df89f69f8d610e9507e0cca9f4a5d4323 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -69,6 +70,14 @@ class CholeskyOuterProduct(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="cholesky_outer_product"): """Instantiates the `CholeskyOuterProduct` bijector. 
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index 9fc1bbf052b419d07a9db149b990c2b80190d72b..07627e1e45eae6b63d830b2adf036bdc3b1d2895 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops.bijectors import power_transform +from tensorflow.python.util import deprecation __all__ = [ @@ -47,6 +48,14 @@ class Exp(power_transform.PowerTransform): over the event space. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="exp"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py new file mode 100644 index 0000000000000000000000000000000000000000..31a9ca27e519bc312813668bf621a875838f12a0 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py @@ -0,0 +1,165 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as dist_util +from tensorflow.python.util import deprecation + + +__all__ = [ + "FillTriangular", +] + + +class FillTriangular(bijector.Bijector): + """Transforms vectors to triangular. + + Triangular matrix elements are filled in a clockwise spiral. + + Given input with shape `batch_shape + [d]`, produces output with + shape `batch_shape + [n, n]`, where + `n = (-1 + sqrt(1 + 8 * d))/2`. + This follows by solving the quadratic equation + `d = 1 + 2 + ... + n = n * (n + 1)/2`. + + #### Example + + ```python + b = tfb.FillTriangular(upper=False) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[4, 0, 0], + # [6, 5, 0], + # [3, 2, 1]] + + b = tfb.FillTriangular(upper=True) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[1, 2, 3], + # [0, 5, 6], + # [0, 0, 4]] + + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + upper=False, + validate_args=False, + name="fill_triangular"): + """Instantiates the `FillTriangular` bijector. 
+ + Args: + upper: Python `bool` representing whether output matrix should be upper + triangular (`True`) or lower triangular (`False`, default). + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._upper = upper + super(FillTriangular, self).__init__( + forward_min_event_ndims=1, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return dist_util.fill_triangular(x, upper=self._upper) + + def _inverse(self, y): + return dist_util.fill_triangular_inverse(y, upper=self._upper) + + def _forward_log_det_jacobian(self, x): + return array_ops.zeros_like(x[..., 0]) + + def _inverse_log_det_jacobian(self, y): + return array_ops.zeros_like(y[..., 0, 0]) + + def _forward_event_shape(self, input_shape): + batch_shape, d = input_shape[:-1], input_shape[-1].value + if d is None: + n = None + else: + n = vector_size_to_square_matrix_size(d, self.validate_args) + return batch_shape.concatenate([n, n]) + + def _inverse_event_shape(self, output_shape): + batch_shape, n1, n2 = (output_shape[:-2], + output_shape[-2].value, + output_shape[-1].value) + if n1 is None or n2 is None: + m = None + elif n1 != n2: + raise ValueError("Matrix must be square. 
(saw [{}, {}])".format(n1, n2)) + else: + m = n1 * (n1 + 1) / 2 + return batch_shape.concatenate([m]) + + def _forward_event_shape_tensor(self, input_shape_tensor): + batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1] + n = vector_size_to_square_matrix_size(d, self.validate_args) + return array_ops.concat([batch_shape, [n, n]], axis=0) + + def _inverse_event_shape_tensor(self, output_shape_tensor): + batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1] + if self.validate_args: + is_square_matrix = check_ops.assert_equal( + n, output_shape_tensor[-2], message="Matrix must be square.") + with ops.control_dependencies([is_square_matrix]): + n = array_ops.identity(n) + d = math_ops.cast(n * (n + 1) / 2, output_shape_tensor.dtype) + return array_ops.concat([batch_shape, [d]], axis=0) + + +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) +def vector_size_to_square_matrix_size(d, validate_args, name=None): + """Convert a vector size to a matrix size.""" + if isinstance(d, (float, int, np.generic, np.ndarray)): + n = (-1 + np.sqrt(1 + 8 * d)) / 2. + if float(int(n)) != n: + raise ValueError("Vector length is not a triangular number.") + return int(n) + else: + with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name: + n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2. 
+ if validate_args: + with ops.control_dependencies([check_ops.assert_equal( + math_ops.to_float(math_ops.to_int32(n)), n, + message="Vector length is not a triangular number")]): + n = array_ops.identity(n) + return math_ops.cast(n, d.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index e656a258e56e71898ecb719dd2af876f158cf799..71e562a927a30a17d695b81c566f981db7553ad9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Gumbel", @@ -45,6 +46,14 @@ class Gumbel(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=0., scale=1., diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index 2bde956d1345129285acae4684256c5ac828b9a1..1504bd27204f728c0cb519159230e945128c4740 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -43,6 +44,14 @@ class Inline(bijector.Bijector): The above example is equivalent to the `Bijector` `Exp()`. 
""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, forward_fn=None, inverse_fn=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index 84a3289ba2160ed22a2bc7030dd612ba9ca6f6df..a648676d4b1956e5c27f67a71e6bd93d0d7fc97d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Invert", @@ -40,6 +41,14 @@ class Invert(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, bijector, validate_args=False, name=None): """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. 
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py index 97000c17262d3efdef10274711364c2bc2083bd4..33b75a04d34fdd01bc0d854d4e5b9c45a737b122 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Kumaraswamy", @@ -44,6 +45,14 @@ class Kumaraswamy(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration1=None, concentration0=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 83667b0e80cfcc1c4f0617cdc739221f24439665..b8f2a4b2c731bdaee78692c036fb9f2fba4e3760 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import template as template_ops from tensorflow.python.ops import variable_scope as variable_scope_lib from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -186,6 +187,14 @@ class MaskedAutoregressiveFlow(bijector.Bijector): Processing Systems_, 2017. 
https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift_and_log_scale_fn, is_constant_jacobian=False, @@ -296,6 +305,14 @@ MASK_INCLUSIVE = "inclusive" MASK_EXCLUSIVE = "exclusive" +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): """Generate the slices for building an autoregressive mask.""" # TODO(b/67594795): Better support of dynamic shape. @@ -313,6 +330,14 @@ def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): return slices +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _gen_mask(num_blocks, n_in, n_out, @@ -327,6 +352,14 @@ def _gen_mask(num_blocks, return mask +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def masked_dense(inputs, units, num_blocks=None, @@ -399,6 +432,14 @@ def masked_dense(inputs, return layer.apply(inputs) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def masked_autoregressive_default_template( hidden_layers, shift_only=False, @@ -515,6 +556,14 @@ def masked_autoregressive_default_template( "masked_autoregressive_default_template", _fn) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): """Clips input while leaving gradient unaltered.""" with ops.name_scope(name, "clip_by_value_preserve_grad", diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py index 71903f705232f0c5e5e0b3271550b4ef938c4f9d..49e6192f067edec4890dcfa107876a5104c14dd4 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -55,6 +56,14 @@ class MatrixInverseTriL(bijector.Bijector): """ + 
@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="matrix_inverse_tril"): """Instantiates the `MatrixInverseTriL` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py index 3f03592f314cc13e8a9ea7e2ae18c5bb1f14e74f..fb393218b6b47764f45b5055bbf15cc17aba219e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -57,6 +58,14 @@ class Ordered(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="ordered"): super(Ordered, self).__init__( forward_min_event_ndims=1, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index 12a16a3f2ba3da53077307fd97d3f10d99b2c81f..f182a1adcbb6b11af2376cd271f903d50e50f1a0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -74,6 +75,14 @@ class Permute(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, permutation, validate_args=False, name=None): """Creates the `Permute` bijector. 
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py index 71f123f2a998458edaa9c8da07ea2932f62625ca..16264fe728a334db347304500767ce5876f9db7e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -41,6 +42,14 @@ class PowerTransform(bijector.Bijector): This bijector is equivalent to the `Exp` bijector when `c=0`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, power=0., validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py index 66e8a5b9b356867424d1d47efaf848fc6903c371..773ae2446118051a61636bc21de6b81dfacda746 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import template as template_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -126,6 +127,14 @@ class RealNVP(bijector.Bijector): Processing Systems_, 2017. 
https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, num_masked, shift_and_log_scale_fn, @@ -228,6 +237,14 @@ class RealNVP(bijector.Bijector): return math_ops.reduce_sum(log_scale, axis=-1) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def real_nvp_default_template( hidden_layers, shift_only=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 5497c422e4d51e259435692dac722f801e8844ac..c8282229a30fabff0c4c267d0bdfcdbce4f5f3d9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -36,10 +37,26 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _static_ndims_from_shape(shape): return shape.shape.with_rank_at_least(1)[0].value +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _ndims_from_shape(shape): return array_ops.shape(shape)[0] @@ -86,6 +103,14 @@ class Reshape(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, event_shape_out, event_shape_in=(-1,), validate_args=False, name=None): """Creates a `Reshape` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..6fbe8665781211ca803feb8bf5a8c04fb0b969e8 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class ScaleTriL(chain.Chain):
  """Transforms unconstrained vectors to TriL matrices with positive diagonal.

  This is implemented as a simple `tfb.Chain` of `tfb.FillTriangular`
  followed by `tfb.TransformDiagonal`, and provided mostly as a
  convenience. The default setup is somewhat opinionated, using a
  Softplus transformation followed by a small shift (`1e-5`) which
  attempts to avoid numerical issues from zeros on the diagonal.

  #### Examples

  ```python
  tfd = tf.contrib.distributions
  tfb = tfd.bijectors

  b = tfb.ScaleTriL(
       diag_bijector=tfb.Exp(),
       diag_shift=None)
  b.forward(x=[0., 0., 0.])
  # Result: [[1., 0.],
  #          [0., 1.]]
  b.inverse(y=[[1., 0],
               [.5, 2]])
  # Result: [log(2), .5, log(1)]

  # Define a distribution over PSD matrices of shape `[3, 3]`,
  # with `1 + 2 + 3 = 6` degrees of freedom.
  dist = tfd.TransformedDistribution(
          tfd.Normal(tf.zeros(6), tf.ones(6)),
          tfb.Chain([tfb.CholeskyOuterProduct(), tfb.ScaleTriL()]))

  # Using an identity transformation, ScaleTriL is equivalent to
  # tfb.FillTriangular.
  b = tfb.ScaleTriL(
       diag_bijector=tfb.Identity(),
       diag_shift=None)

  # For greater control over initialization, one can manually encode
  # pre- and post- shifts inside of `diag_bijector`.
  b = tfb.ScaleTriL(
       diag_bijector=tfb.Chain([
         tfb.AffineScalar(shift=1e-3),
         tfb.Softplus(),
         tfb.AffineScalar(shift=0.5413)]),  # softplus_inverse(1.)
                                            #  = log(expm1(1.)) = 0.5413
       diag_shift=None)
  ```
  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               diag_bijector=None,
               diag_shift=1e-5,
               validate_args=False,
               name="scale_tril"):
    """Instantiates the `ScaleTriL` bijector.

    Args:
      diag_bijector: `Bijector` instance, used to transform the output diagonal
        to be positive.
        Default value: `None` (i.e., `tfb.Softplus()`).
      diag_shift: Float value broadcastable and added to all diagonal entries
        after applying the `diag_bijector`. Setting a positive
        value forces the output diagonal entries to be positive, but
        prevents inverting the transformation for matrices with
        diagonal entries less than this value. Pass `None` to apply no
        shift at all.
        Default value: `1e-5` (i.e., a shift of `1e-5` is applied).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False` (i.e., arguments are not validated).
      name: Python `str` name given to ops managed by this object.
        Default value: `scale_tril`.
    """

    if diag_bijector is None:
      diag_bijector = softplus.Softplus(validate_args=validate_args)

    if diag_shift is not None:
      # The shift is listed first so that it is applied *after*
      # `diag_bijector` in the forward direction (the list is in
      # last-applied-first order, as with the outer chain below).
      diag_bijector = chain.Chain([affine_scalar.AffineScalar(shift=diag_shift),
                                   diag_bijector])

    # Forward pass: fill the lower triangle from the input vector, then
    # transform the diagonal of the resulting matrix.
    super(ScaleTriL, self).__init__(
        [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
         fill_triangular.FillTriangular()],
        validate_args=validate_args,
        name=name)
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="sigmoid"): super(Sigmoid, self).__init__( forward_min_event_ndims=0, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index 2a32e8abcde940b0056b0faf2955ec1b3bd71803..241fba2cb7ec33b7b02c1ca79051f1b826d7d2aa 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -26,12 +26,21 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "SinhArcsinh", ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _sqrtx2p1(x): """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" return array_ops.where( @@ -88,6 +97,14 @@ class SinhArcsinh(bijector.Bijector): `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, skewness=None, tailweight=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index f52b91550edff7390d8094a4508d862674e85d59..20ee0d340833d5c5275e2ab52a89dcdf7198add1 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -60,6 +61,14 @@ class SoftmaxCentered(bijector.Bijector): makes the (forward) image non-open and the theorem does not directly apply. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="softmax_centered"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 96a938c803418ff818f9c531754b47ba1eb8667a..3df84ef8b04c2c8f6be91ecd1c972ad1484b4285 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -80,6 +81,14 @@ class Softplus(bijector.Bijector): "hinge_softness": ( "Nonzero floating point `Tensor`. Controls the softness of what " "would otherwise be a kink at the origin. Default is 1.0")}) + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, hinge_softness=None, validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py index b4a658c171b8313358754228aabbfa4bf93fd84d..f96a4bb01de59a21107b9e7c14f929e13e358ac9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py @@ -22,6 +22,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -51,6 +52,14 @@ class Softsign(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="softsign"): super(Softsign, self).__init__( forward_min_event_ndims=0, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py index 2ccfdc95970e387e708603e2614ad29fb6a18db3..294460a80f6209797831ea361e64efe677f71e59 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/square.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/square.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -49,6 +50,14 @@ class Square(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="square"): """Instantiates the `Square` bijector. @@ -81,4 +90,3 @@ class Square(bijector.Bijector): is_valid = check_ops.assert_non_negative( t, message="All elements must be non-negative.") return control_flow_ops.with_dependencies([is_valid], t) - diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7a3b026b8dcc31bed49c489d77b9c184f463cb --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py @@ -0,0 +1,111 @@ +# Copyright 2018 The TensorFlow Authors. 
class TransformDiagonal(bijector.Bijector):
  """Bijector that transforms only the diagonal entries of a matrix.

  The wrapped `diag_bijector` is applied to the diagonal of the input;
  off-diagonal entries pass through unchanged.

  #### Example

  ```python
  tfb = tf.contrib.distributions.bijectors
  b = tfb.TransformDiagonal(diag_bijector=tfb.Exp())

  b.forward([[1., 0.],
             [0., 1.]])
  # ==> [[2.718, 0.],
  #      [0., 2.718]]
  ```

  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               diag_bijector,
               validate_args=False,
               name="transform_diagonal"):
    """Instantiates the `TransformDiagonal` bijector.

    Args:
      diag_bijector: `Bijector` instance used to transform the diagonal.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.
    """
    self._diag_bijector = diag_bijector
    # Events are matrices, hence a minimum event rank of 2 in both directions.
    base_kwargs = dict(
        forward_min_event_ndims=2,
        inverse_min_event_ndims=2,
        validate_args=validate_args,
        name=name)
    super(TransformDiagonal, self).__init__(**base_kwargs)

  def _forward(self, x):
    """Returns `x` with its diagonal pushed forward through the bijector."""
    original_diag = array_ops.matrix_diag_part(x)
    transformed_diag = self._diag_bijector.forward(original_diag)
    return array_ops.matrix_set_diag(x, transformed_diag)

  def _inverse(self, y):
    """Returns `y` with its diagonal pulled back through the bijector."""
    original_diag = array_ops.matrix_diag_part(y)
    recovered_diag = self._diag_bijector.inverse(original_diag)
    return array_ops.matrix_set_diag(y, recovered_diag)

  def _forward_log_det_jacobian(self, x):
    # Think of the Jacobian in terms of the flattened matrices `vec(x)`
    # and `vec(y)`, ordered so that the `n` diagonal entries come first
    # and the `n**2 - n` off-diagonal entries follow in any order. It is
    # then block-diagonal:
    #
    # J_vec(x) (vec(y)) =
    # -------------------------------
    # | J_diag(x) (diag(y))      0  | n entries
    # |                             |
    # | 0                        I  | n**2-n entries
    # -------------------------------
    #   n                     n**2-n
    #
    # The identity block (off-diagonal entries are left untouched)
    # contributes a log-det of zero, so the total is just the log-det of
    # the diagonal bijector's block. That block need not itself be
    # diagonal, although it is for elementwise bijectors such as exp or
    # softplus. The diagonal is a vector, hence `event_ndims=1`.
    diag_of_x = array_ops.matrix_diag_part(x)
    return self._diag_bijector.forward_log_det_jacobian(
        diag_of_x, event_ndims=1)

  def _inverse_log_det_jacobian(self, y):
    # Mirror image of the forward case: only the diagonal block contributes.
    diag_of_y = array_ops.matrix_diag_part(y)
    return self._diag_bijector.inverse_log_det_jacobian(
        diag_of_y, event_ndims=1)
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, scale=1., concentration=1., diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py index e4944beedcbca09b5eabd4daf1445ce4503b1c80..b349e5966dd750fdf96c0b211dce02658c9400b7 100644 --- a/tensorflow/contrib/distributions/python/ops/binomial.py +++ b/tensorflow/contrib/distributions/python/ops/binomial.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation _binomial_sample_note = """ @@ -42,6 +43,14 @@ to integer values. """ +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _bdtr(k, n, p): """The binomial cumulative distribution function. @@ -130,6 +139,14 @@ class Binomial(distribution.Distribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, total_count, logits=None, diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index 23b6a83c17d58652001543047febeebabba0c69f..cb5223b0557080e10bf24c3e1cb432f15fd5e7e3 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation __all__ = [ "Cauchy", @@ -92,6 +93,14 @@ class Cauchy(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index 686ae1ba74641e2b7b76667e512fa6453477a8da..e9a7b39070f3d76693ad54852ed0847a0980d2a6 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import gamma +from tensorflow.python.util import deprecation __all__ = [ @@ -63,6 +64,14 @@ class Chi2(gamma.Gamma): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, validate_args=False, @@ -114,6 +123,14 @@ class Chi2(gamma.Gamma): class Chi2WithAbsDf(Chi2): """Chi2 with parameter transform `df = floor(abs(df))`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index c44c76a133817640449ba126bb8ca25abadba5e6..ad853ee293f86565c1af601214522f53d936b70a 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation __all__ = [ "Deterministic", @@ -43,6 +44,14 @@ __all__ = [ class _BaseDeterministic(distribution.Distribution): """Base class for Deterministic distributions.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, @@ -203,6 +212,14 @@ class Deterministic(_BaseDeterministic): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, @@ -308,6 +325,14 @@ class VectorDeterministic(_BaseDeterministic): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
def move_dimension(x, source_idx, dest_idx):
  """Move a single tensor dimension within its shape.

  This is a special case of `tf.transpose()`, which applies
  arbitrary permutations to tensor dimensions.

  Args:
    x: Tensor of rank `ndims`.
    source_idx: Integer index into `x.shape` (negative indexing is
      supported). May be a Python integer or an integer `Tensor`.
    dest_idx: Integer index into `x.shape` (negative indexing is
      supported). May be a Python integer or an integer `Tensor`.

  Returns:
    x_perm: Tensor of rank `ndims`, in which the dimension at original
      index `source_idx` has been moved to new index `dest_idx`, with
      all other dimensions retained in their original order.

  Example:

  ```python
  x = tf.placeholder(dtype=tf.float32, shape=[200, 30, 4, 1, 6])
  x_perm = move_dimension(x, 1, 1)   # no-op
  x_perm = move_dimension(x, 0, 3)   # result shape [30, 4, 1, 200, 6]
  x_perm = move_dimension(x, 0, -2)  # equivalent to previous
  x_perm = move_dimension(x, 4, 2)   # result shape [200, 30, 6, 4, 1]
  ```
  """
  ndims = util.prefer_static_rank(x)
  # The permutation dtype follows `source_idx`.
  # NOTE(review): when both indices are Tensors they presumably need to
  # share a dtype for the `range`/`concat` ops below -- confirm with callers.
  if isinstance(source_idx, int):
    dtype = dtypes.int32
  else:
    dtype = dtypes.as_dtype(source_idx.dtype)

  def _wrap_negative(idx):
    """Maps a negative index onto its non-negative equivalent."""
    if isinstance(idx, ops.Tensor):
      # A Tensor's sign cannot be used as a Python bool in graph mode, so
      # select between the two candidates instead of branching on it.
      return util.pick_scalar_condition(idx < 0, ndims + idx, idx)
    return ndims + idx if idx < 0 else idx

  # Handle negative indexing. Since ndims might be dynamic, this makes
  # source_idx and dest_idx also possibly dynamic.
  source_idx = _wrap_negative(source_idx)
  dest_idx = _wrap_negative(dest_idx)

  # Construct the appropriate permutation of dimensions, depending
  # whether the source is before or after the destination.
  def move_left_permutation():
    return util.prefer_static_value(
        array_ops.concat([
            math_ops.range(0, dest_idx, dtype=dtype),
            [source_idx],
            math_ops.range(dest_idx, source_idx, dtype=dtype),
            math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0))

  def move_right_permutation():
    return util.prefer_static_value(
        array_ops.concat([
            math_ops.range(0, source_idx, dtype=dtype),
            math_ops.range(source_idx+1, dest_idx+1, dtype=dtype),
            [source_idx],
            math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0))

  def x_permuted():
    return array_ops.transpose(
        x, perm=smart_cond.smart_cond(source_idx < dest_idx,
                                      move_right_permutation,
                                      move_left_permutation))

  # One final conditional to handle the special case where source
  # and destination indices are equal.
  return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx),
                               lambda: x,
                               x_permuted)
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def estimator_head_distribution_regression(make_distribution_fn, label_dimension=1, logits_dimension=None, @@ -77,6 +86,14 @@ def estimator_head_distribution_regression(make_distribution_fn, class _DistributionRegressionHead(_RegressionHead): """Creates a _RegressionHead instance from an arbitrary `Distribution`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, make_distribution_fn, label_dimension, diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py index e1e42ee95d200df30c2c8a53a89cb5b7e9c4d17c..d62f024aa2a081f0ec231015af1f26a8851518e9 100644 --- a/tensorflow/contrib/distributions/python/ops/geometric.py +++ b/tensorflow/contrib/distributions/python/ops/geometric.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Geometric(distribution.Distribution): @@ -55,6 +56,14 @@ class Geometric(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, logits=None, probs=None, diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index 9d94fd11c62ce6ecd3d7daee35447bece2b4b2fb..acdea4d61d3ada7e9f4f0aa7bc58c5643db2802b 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation class _Gumbel(distribution.Distribution): @@ -96,6 +97,14 @@ class _Gumbel(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index 9c96254d1c0a593b955231132330931ff5f4ad07..b02c4031069191592b8acc1a90313450f98af6d7 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import special_math +from tensorflow.python.util import deprecation __all__ = [ @@ -85,6 +86,14 @@ class HalfNormal(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, scale, validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index cd6eaa8407477b4ed92f169bc0d2d80644d7c956..0672702b96c1eb81c176774554df3f5922a0319e 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import kullback_leibler +from tensorflow.python.util import deprecation class Independent(distribution_lib.Distribution): @@ -94,6 +95,14 @@ class Independent(distribution_lib.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, distribution, reinterpreted_batch_ndims=None, validate_args=False, name=None): @@ -258,6 +267,14 @@ class Independent(distribution_lib.Distribution): @kullback_leibler.RegisterKL(Independent, Independent) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_independent(a, b, name="kl_independent"): """Batched KL divergence `KL(a || b)` for Independent distributions. 
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 208057b34db2881b5c9c2adb102d02a87a333007..70d050d7a647b38928ddb1c788db0e6957ac0f03 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -95,6 +96,14 @@ class InverseGamma(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration, rate, @@ -274,6 +283,14 @@ class InverseGamma(distribution.Distribution): class InverseGammaWithSoftplusConcentrationRate(InverseGamma): """`InverseGamma` with softplus of `concentration` and `rate`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration, rate, diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 0ff989fc952c6fb3f54dad9a943eb36a0494a3be..e3712dd84e36609d6bba4a5a39866046c0c8d1d8 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import uniform +from tensorflow.python.util import deprecation __all__ = [ "Kumaraswamy", @@ -40,6 +41,14 @@ _kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in `[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _harmonic_number(x): """Compute the harmonic number from its analytic continuation. @@ -123,6 +132,14 @@ class Kumaraswamy(transformed_distribution.TransformedDistribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration1=None, concentration0=None, diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 27aa863440574eb0cdb5c7ae326e877d472999ad..02e3bad51ee48188acf83cb09359861c9e6932c7 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation class Logistic(distribution.Distribution): @@ -91,6 +92,14 @@ class Logistic(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index bfb53a06c011cec60cf5b2132e4b1106128a1ece..3b7114ef067c0aaede23fff04c40d1dc6e830f1c 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import categorical from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Mixture(distribution.Distribution): @@ -66,6 +67,14 @@ class Mixture(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, cat, components, diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 112eefd3691815ead19d59bc3aef5909b27ed169..8ffee940d03c9a5204f2ac6f7acd9ea482adae1a 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class MixtureSameFamily(distribution.Distribution): @@ -95,6 +96,14 @@ class MixtureSameFamily(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, mixture_distribution, components_distribution, @@ -321,6 +330,14 @@ class MixtureSameFamily(distribution.Distribution): return x +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _outer_squared_difference(x, y): """Convenience function analogous to tf.squared_difference.""" z = x - y diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index d2beb2aff0481eb4ec3a3abbf44fad5efff8eedd..cd0c282ba6cebf784261a4e821f36ce4eed98fe0 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -22,6 +22,7 @@ from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops import nn +from tensorflow.python.util import deprecation __all__ = [ @@ -134,6 +135,14 @@ class MultivariateNormalDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -218,6 +227,14 @@ class MultivariateNormalDiag( class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag): """MultivariateNormalDiag with `diag_stddev = softplus(diag_stddev)`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale_diag, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 5117379b047f5e510a8a1a5490ddf76ee93d9d74..d8401801f21afbe8fd042053c6a38a31a2539438 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -22,6 +22,7 @@ from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -141,6 +142,14 @@ class MultivariateNormalDiagPlusLowRank( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 57f47db50c496f1e3e80d8177560b1bab594eb56..dbc4c1b3dc956641f3e38ffafe3a3410bd3e2097 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.util import deprecation __all__ = [ @@ -112,6 +113,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, covariance_matrix=None, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 6a0383db02555274239ee0b1845f24a705270d84..efe5a6d0d99ca8fa9e0274049423bb3c4eef2d6f 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -27,6 +27,7 @@ from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = [ @@ -133,6 +134,14 @@ class MultivariateNormalLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, @@ -266,6 +275,14 @@ class MultivariateNormalLinearOperator( @kullback_leibler.RegisterKL(MultivariateNormalLinearOperator, MultivariateNormalLinearOperator) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_brute_force(a, b, name=None): """Batched KL divergence `KL(a || b)` for multivariate Normals. 
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index c809ef3c1cb5b8b9cd892b98d81e57710807d0aa..d9110947ecdbba1a63669573f46db17b02e512ab 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -22,6 +22,7 @@ from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -134,6 +135,14 @@ class MultivariateNormalTriL( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_tril=None, diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py index 2bd11e24b315e044624344580108a232d1b6da89..6acfc5746a0cc20e916de81b71f90e08d8d91ad5 100644 --- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py +++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class NegativeBinomial(distribution.Distribution): @@ -51,6 +52,14 @@ class NegativeBinomial(distribution.Distribution): * `n!` is the factorial of `n`. 
""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, total_count, logits=None, diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index 3e44c10fab726ad1299cc852a5e1391fecb8b390..0c762f17c9b770ecada57b6ce60a4825ba374dd9 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class OneHotCategorical(distribution.Distribution): @@ -83,6 +84,14 @@ class OneHotCategorical(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, logits=None, @@ -233,6 +242,14 @@ class OneHotCategorical(distribution.Distribution): @kullback_leibler.RegisterKL(OneHotCategorical, OneHotCategorical) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_categorical_categorical(a, b, name=None): """Calculate the batched KL divergence KL(a || b) with a, b OneHotCategorical. diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index 04de8106ee0c06f4bc888964e053eb3123f3dab3..3d055085cc7386e57a71aa310458b7666bb9a396 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ "Poisson", @@ -65,6 +66,14 @@ class Poisson(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, rate=None, log_rate=None, diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 7b10ba998f0ceac37571524ce858bbd4c87455fe..7a7ad1be35b80ff0f000181ea0778ab282a8220f 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -33,6 +33,7 @@ from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.distributions import transformed_distribution as transformed_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -42,6 +43,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_lognormal_gauss_hermite( loc, scale, quadrature_size, validate_args=False, name=None): # pylint: disable=unused-argument @@ -85,6 +94,14 @@ def quadrature_scheme_lognormal_gauss_hermite( return grid, probs +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_lognormal_quantiles( loc, scale, quadrature_size, validate_args=False, name=None): @@ -214,6 +231,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): validate_args=True) """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -417,6 +442,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): axis=[-2, -1]) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" args_ = [distribution_util.static_value(x) for x in args] diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index 5ac6c34b538016af376f53aa5a889e78c1f65f5f..ef3bdfa75fcaa8df17db1238ceadadf788601356 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -27,10 +27,19 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distributions from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = ["QuantizedDistribution"] 
+@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _logsum_expbig_minus_expsmall(big, small): """Stable evaluation of `Log[exp{big} - exp{small}]`. @@ -228,6 +237,14 @@ class QuantizedDistribution(distributions.Distribution): https://arxiv.org/abs/1711.10433 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution, low=None, diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index 4182ca2b56ea80dba71787b006a1652e0f979694..7e1f64dc425e6a576bfbe1bb456901fddfac26e1 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -19,15 +19,16 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import logistic +from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid # Bijectors must be directly imported because `remove_undocumented` prevents # individual file imports. 
-from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class RelaxedBernoulli(transformed_distribution.TransformedDistribution): @@ -131,6 +132,14 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): Gumbel-Softmax. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, temperature, logits=None, diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 5414f347cd65e2d3327d1934cbc7a91e7f780fc5..9b5bd7576f2a3c364e21da76dd3905a8c6e35829 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class ExpRelaxedOneHotCategorical(distribution.Distribution): @@ -125,6 +126,14 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): A Continuous Relaxation of Discrete Random Variables. 2016. 
""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, temperature, @@ -368,6 +377,14 @@ class RelaxedOneHotCategorical( A Continuous Relaxation of Discrete Random Variables. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, temperature, diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py index 6a7f28713acefd2285b07a212e2e47a6db1ae5e1..4f348be2806aa3ade7c1ea2a7bc68ca26db6447f 100644 --- a/tensorflow/contrib/distributions/python/ops/shape.py +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class _DistributionShape(object): @@ -166,6 +167,14 @@ class _DistributionShape(object): "free," i.e., during graph construction. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, batch_ndims=None, event_ndims=None, diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index a764544932cea8a624820153e383595fec9d7fc6..a9d0fb4ccfb1803873f7fe17089f3e7c7f10f4b7 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation __all__ = [ "SinhArcsinh", @@ -94,6 +95,14 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 8d4914e16cd3748e81e3d9b3be8b35f64a1c6f0d..ece03fe4aab3cc3046e0958d883ca9388517b94b 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -40,6 +40,7 @@ from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -49,6 +50,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_softmaxnormal_gauss_hermite( normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): @@ -111,6 +120,14 @@ def quadrature_scheme_softmaxnormal_gauss_hermite( return grid, probs +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_softmaxnormal_quantiles( normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): @@ -318,6 +335,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): https://arxiv.org/abs/1801.03080 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, mix_loc, temperature, @@ -779,6 +804,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): return array_ops.reshape(p, shape=expand_shape) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def maybe_check_quadrature_param(param, name, validate_args): """Helper which checks validity of `loc` and `scale` init args.""" with ops.name_scope(name="check_" + name, values=[param]): @@ -812,6 +845,14 @@ def maybe_check_quadrature_param(param, name, validate_args): return param +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def determine_batch_event_shapes(grid, endpoint_affine): """Helper to infer batch_shape and event_shape.""" with ops.name_scope(name="determine_batch_event_shapes"): @@ -850,6 +891,14 @@ def determine_batch_event_shapes(grid, endpoint_affine): return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def interpolate_loc(grid, loc): """Helper which interpolates between two locs.""" if len(loc) != 2: @@ -876,6 +925,14 @@ def interpolate_loc(grid, loc): return [x[..., k] for k in range(deg)] # list(shape:[B, e]) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def interpolate_scale(grid, scale): """Helper which interpolates between two scales.""" if len(scale) != 2: @@ -892,6 +949,14 @@ def interpolate_scale(grid, scale): ])[0] for q in range(deg)] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def linop_scale(w, op): # We assume w > 0. (This assumption only relates to the is_* attributes.) 
with ops.name_scope("linop_scale", values=[w]): @@ -927,6 +992,14 @@ def linop_scale(w, op): "Unsupported Linop type ({})".format(type(op).__name__)) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" args_ = [distribution_util.static_value(x) for x in args] @@ -935,6 +1008,14 @@ def concat_vectors(*args): return [val for vec in args_ for val in vec] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def add(x, y): """Adds inputs; interprets `None` as zero.""" if x is None: @@ -944,11 +1025,27 @@ def add(x, y): return x + y +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def vec_osquare(x): """Computes the outer-product of a (batch of) vector, i.e., x.T x.""" return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def softmax(x, axis, name=None): """Equivalent to tf.nn.softmax but works around b/70297725.""" with ops.name_scope(name, "softmax", [x, axis]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index a75b3f3df1f2867f214f47051fa358b79a52a35e..73356a3625c9a1aa15af5b6c1cf2ccb0c514b39a 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import vector_exponential_linear_operator as vector_exponential_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -116,6 +117,14 @@ class VectorExponentialDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index a7d4c55be93f6190ae4d6976030190f27dcfe48f..9a47b4855763a25b484ad04a3415d191f19256f7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import exponential from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = ["VectorExponentialLinearOperator"] @@ -138,6 +139,14 @@ class VectorExponentialLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 4a53e7a621f27382d2995798f724392d34459670..e68ddc569c95ff63760b4b2f6d7a92f17240a558 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import vector_laplace_linear_operator as vector_laplace_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -151,6 +152,14 @@ class VectorLaplaceDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index 0566e04fece6f9ca0de6903ce5c424eccbc003cd..3923161a332a77e4eaab8d65d96fd8c278c872ec 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import laplace from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = [ @@ -154,6 +155,14 @@ class VectorLaplaceLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index bb33cd0762a368eb7e53f1623ede9231e80f0b14..49ffff24caec8d6c525f65f06796d10548d5ec40 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation __all__ = [ "VectorSinhArcsinhDiag", @@ -95,6 +96,14 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 21f84dcbdea8b422dd45fadeac1bb8b2804c551f..f289b39e51aff36780541a0545ed9e6cfe21dd4e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import student_t from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation class _VectorStudentT(transformed_distribution.TransformedDistribution): @@ -121,6 +122,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, loc=None, diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index 88d4280759da7ca685056f4d41cf8dc51393c9f3..f1accaaa4c920344608015c792a2c3606de1337f 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation __all__ = [ "WishartCholesky", @@ -73,6 +74,14 @@ class _WishartLinearOperator(distribution.Distribution): this class. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale_operator, @@ -501,6 +510,14 @@ class WishartCholesky(_WishartLinearOperator): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale, @@ -617,6 +634,14 @@ class WishartFull(_WishartLinearOperator): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). 
You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale, diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 1d9371c7ac405dbf0ec40210270b90f2cf9b9a25..6f02c90368d966b8cf8d0dee09f9d2a5013c90c1 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -11,6 +11,8 @@ py_library( "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", "//tensorflow/contrib/eager/python/examples/linear_regression", "//tensorflow/contrib/eager/python/examples/resnet50", + "//tensorflow/contrib/eager/python/examples/revnet", + "//tensorflow/contrib/eager/python/examples/revnet:config", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", "//tensorflow/contrib/eager/python/examples/spinn:data", diff --git a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb index 4fe3a0e3f3d431684973a9251aa3d92bf2010444..5749f22ac58e0a012ed7e3fec4dfe2913d3f8273 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb @@ -68,7 +68,7 @@ "# simply construct the object. 
Most layers take as a first argument the number\n", "# of output dimensions / channels.\n", "layer = tf.keras.layers.Dense(100)\n", - "# The number of input dimensionss is often unnecessary, as it can be inferred\n", + "# The number of input dimensions is often unnecessary, as it can be inferred\n", "# the first time the layer is used, but it can be provided if you want to \n", "# specify it manually, which is useful in some complex models.\n", "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" @@ -267,7 +267,7 @@ " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", " * `call`, where you do the forward computation\n", "\n", - "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes requires to create the variables will need to be explicitly specified." + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified." 
] }, { diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..bfb53cfff86650c28fdd934763b1fb40cc5c796c --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -0,0 +1,76 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# Model +py_library( + name = "ops", + srcs = ["ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "config", + srcs = ["config.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "blocks", + srcs = ["blocks.py"], + srcs_version = "PY2AND3", + deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "revnet", + srcs = ["revnet.py"], + srcs_version = "PY2AND3", + deps = [ + ":blocks", + "//tensorflow:tensorflow_py", + ], +) + +# Tests +cuda_py_test( + name = "ops_test", + size = "large", + srcs = ["ops_test.py"], + additional_deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "blocks_test", + size = "large", + srcs = ["blocks_test.py"], + additional_deps = [ + ":blocks", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "revnet_test", + size = "large", + srcs = ["revnet_test.py"], + additional_deps = [ + ":config", + ":revnet", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..fb4f9f068f062802cda4610ced01c50da3836e04 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py @@ -0,0 +1,335 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Building blocks with manual backward gradient computation. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import ops + + +class RevBlock(tf.keras.Model): + """Single reversible block containing several `_Residual` blocks. + + Each `_Residual` block in turn contains two _ResidualInner blocks, + corresponding to the `F`/`G` functions in the paper. + """ + + def __init__(self, + n_res, + filters, + strides, + input_shape, + batch_norm_first=False, + data_format="channels_first", + bottleneck=False, + fused=True): + """Initialize RevBlock. 
+ + Args: + n_res: number of residual blocks + filters: list/tuple of integers for output filter sizes of each residual + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + bottleneck: use bottleneck residual if True + fused: use fused batch normalization if True + """ + super(RevBlock, self).__init__() + self.blocks = tf.contrib.checkpoint.List() + for i in range(n_res): + curr_batch_norm_first = batch_norm_first and i == 0 + curr_strides = strides if i == 0 else (1, 1) + block = _Residual( + filters, + curr_strides, + input_shape, + batch_norm_first=curr_batch_norm_first, + data_format=data_format, + bottleneck=bottleneck, + fused=fused) + self.blocks.append(block) + + if data_format == "channels_first": + input_shape = (filters, input_shape[1] // curr_strides[0], + input_shape[2] // curr_strides[1]) + else: + input_shape = (input_shape[0] // curr_strides[0], + input_shape[1] // curr_strides[1], filters) + + def call(self, h, training=True): + """Apply reversible block to inputs.""" + + for block in self.blocks: + h = block(h, training=training) + return h + + def backward_grads_and_vars(self, x, y, dy, training=True): + """Apply reversible block backward to outputs.""" + + grads_all = [] + vars_all = [] + + for i in reversed(range(len(self.blocks))): + block = self.blocks[i] + y_inv = x if i == 0 else block.backward(y, training=training) + dy, grads, vars_ = block.backward_grads_and_vars( + y_inv, dy, training=training) + grads_all += grads + vars_all += vars_ + + return dy, grads_all, vars_all + + +class _Residual(tf.keras.Model): + """Single residual block contained in a _RevBlock. Each `_Residual` object has + two _ResidualInner objects, corresponding to the `F` and `G` functions in the + paper. 
+ + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC", + bottleneck: use bottleneck residual if True + fused: use fused batch normalization if True + """ + + def __init__(self, + filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + bottleneck=False, + fused=True): + super(_Residual, self).__init__() + + self.filters = filters + self.strides = strides + self.axis = 1 if data_format == "channels_first" else 3 + if data_format == "channels_first": + f_input_shape = (input_shape[0] // 2,) + input_shape[1:] + g_input_shape = (filters // 2, input_shape[1] // strides[0], + input_shape[2] // strides[1]) + else: + f_input_shape = input_shape[:2] + (input_shape[2] // 2,) + g_input_shape = (input_shape[0] // strides[0], + input_shape[1] // strides[1], filters // 2) + + factory = _BottleneckResidualInner if bottleneck else _ResidualInner + self.f = factory( + filters=filters // 2, + strides=strides, + input_shape=f_input_shape, + batch_norm_first=batch_norm_first, + data_format=data_format, + fused=fused) + self.g = factory( + filters=filters // 2, + strides=(1, 1), + input_shape=g_input_shape, + batch_norm_first=batch_norm_first, + data_format=data_format, + fused=fused) + + def call(self, x, training=True, concat=True): + """Apply residual block to inputs.""" + + x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis) + f_x2 = self.f.call(x2, training=training) + # TODO(lxuechen): Replace with simpler downsampling + x1_down = ops.downsample( + x1, self.filters // 2, self.strides, axis=self.axis) + x2_down = ops.downsample( + x2, self.filters // 2, self.strides, axis=self.axis) + y1 = f_x2 + x1_down + g_y1 = self.g.call(y1, training=training) # self.g(y1) gives pylint error + y2 = g_y1 + x2_down + if not 
concat: # Concat option needed for correct backward grads + return y1, y2 + return tf.concat([y1, y2], axis=self.axis) + + def backward(self, y, training=True): + """Reconstruct inputs from outputs; only valid when stride 1.""" + + assert self.strides == (1, 1) + + y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis) + g_y1 = self.g.call(y1, training=training) + x2 = y2 - g_y1 + f_x2 = self.f.call(x2, training=training) + x1 = y1 - f_x2 + + return tf.concat([x1, x2], axis=self.axis) + + def backward_grads_and_vars(self, x, dy, training=True): + """Manually compute backward gradients given input and output grads.""" + + with tf.GradientTape(persistent=True) as tape: + x_stop = tf.stop_gradient(x) + x1, x2 = tf.split(x_stop, num_or_size_splits=2, axis=self.axis) + tape.watch([x1, x2]) + # Stitch back x for `call` so tape records correct grads + x = tf.concat([x1, x2], axis=self.axis) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis) + y1, y2 = self.call(x, training=training, concat=False) + x2_down = ops.downsample( + x2, self.filters // 2, self.strides, axis=self.axis) + + grads_combined = tape.gradient( + y2, [y1] + self.g.variables, output_gradients=[dy2]) + dy2_y1, dg = grads_combined[0], grads_combined[1:] + dy1_plus = dy2_y1 + dy1 + + grads_combined = tape.gradient( + y1, [x1, x2] + self.f.variables, output_gradients=[dy1_plus]) + dx1, dx2, df = grads_combined[0], grads_combined[1], grads_combined[2:] + dx2 += tape.gradient(x2_down, [x2], output_gradients=[dy2])[0] + + del tape + + grads = df + dg + vars_ = self.f.variables + self.g.variables + + return tf.concat([dx1, dx2], axis=self.axis), grads, vars_ + + +def _BottleneckResidualInner(filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True): + """Single bottleneck residual inner function contained in _Resdual. + + Corresponds to the `F`/`G` functions in the paper. + Suitable for training on ImageNet dataset. 
+ + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + + Returns: + A keras model + """ + + axis = 1 if data_format == "channels_first" else 3 + model = tf.keras.Sequential() + if batch_norm_first: + model.add( + tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused)) + model.add(tf.keras.layers.LeakyReLU(alpha=0.)) + model.add( + tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=1, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, + padding="SAME")) + + model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) + model.add(tf.keras.layers.LeakyReLU(alpha=0.)) + model.add( + tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME")) + + model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) + model.add(tf.keras.layers.LeakyReLU(alpha=0.)) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=1, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME")) + + return model + + +def _ResidualInner(filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True): + """Single residual inner function contained in _ResdualBlock. + + Corresponds to the `F`/`G` functions in the paper. 
+ + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + + Returns: + A keras model + """ + + axis = 1 if data_format == "channels_first" else 3 + model = tf.keras.Sequential() + if batch_norm_first: + model.add( + tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused)) + model.add(tf.keras.layers.LeakyReLU(alpha=0.)) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, + padding="SAME")) + + model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) + model.add(tf.keras.layers.LeakyReLU(alpha=0.)) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME")) + + return model diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f4436fd92506d54f1206fbfd424b897f9835657d --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py @@ -0,0 +1,346 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for basic building blocks used in eager mode RevNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import blocks + + +def _validate_block_call_channels_last(block_factory, test): + """Generic testing function for `channels_last` data format. + + Completes a set of tests varying data format, stride, and batch normalization + configured train vs test time. + Args: + block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock, + blocks._ResidualInner + test: tf.test.TestCase object + """ + with tf.device("/cpu:0"): # NHWC format + input_shape = (224, 224, 32) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride 1 + block = block_factory( + filters=64, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 224, 224, 64)) + test.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = block_factory( + filters=64, + strides=(2, 2), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 112, 112, 64)) + test.assertNotAllClose(y_tr, y_ev) + + +def _validate_block_call_channels_first(block_factory, test): + """Generic testing function for `channels_first` data format. + + Completes a set of tests varying data format, stride, and batch normalization + configured train vs test time. 
+ Args: + block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock, + blocks._ResidualInner + test: tf.test.TestCase object + """ + if not tf.test.is_gpu_available(): + test.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (32, 224, 224) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride of 1 + block = block_factory(filters=64, strides=(1, 1), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 64, 224, 224)) + test.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = block_factory(filters=64, strides=(2, 2), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 64, 112, 112)) + test.assertNotAllClose(y_tr, y_ev) + + +class RevBlockTest(tf.test.TestCase): + + def test_call_channels_first(self): + """Test `call` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (32, 224, 224) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride of 1 + block = blocks.RevBlock( + n_res=3, filters=64, strides=(1, 1), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, (16, 64, 224, 224)) + self.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = blocks.RevBlock( + n_res=3, filters=64, strides=(2, 2), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, [16, 64, 112, 112]) + self.assertNotAllClose(y_tr, y_ev) + + def test_call_channels_last(self): + """Test 
`call` function with `channels_last` data format.""" + with tf.device("/cpu:0"): # NHWC format + input_shape = (224, 224, 32) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride 1 + block = blocks.RevBlock( + n_res=3, + filters=64, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, (16, 224, 224, 64)) + self.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = blocks.RevBlock( + n_res=3, + filters=64, + strides=(2, 2), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, (16, 112, 112, 64)) + self.assertNotAllClose(y_tr, y_ev) + + def test_backward_grads_and_vars_channels_first(self): + """Test `backward` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (32, 224, 224) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride 1 + y = tf.random_normal(shape=data_shape) + dy = tf.random_normal(shape=data_shape) + block = blocks.RevBlock( + n_res=3, filters=32, strides=(1, 1), input_shape=input_shape) + dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) + self.assertEqual(dy.shape, x.shape) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + + # Stride 2 + y = tf.random_normal(shape=(16, 32, 112, 112)) + dy = tf.random_normal(shape=(16, 32, 112, 112)) + block = blocks.RevBlock( + n_res=3, filters=32, strides=(2, 2), input_shape=input_shape) + dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) + self.assertEqual(dy.shape, x.shape) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, 
list)) + + def test_backward_grads_and_vars_channels_last(self): + """Test `backward` function with `channels_last` data format.""" + with tf.device("/cpu:0"): # NHWC format + input_shape = (224, 224, 32) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride 1 + y = tf.random_normal(shape=data_shape) + dy = tf.random_normal(shape=data_shape) + block = blocks.RevBlock( + n_res=3, + filters=32, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) + self.assertEqual(dy.shape, x.shape) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + + # Stride 2 + y = tf.random_normal(shape=(16, 112, 112, 32)) + dy = tf.random_normal(shape=(16, 112, 112, 32)) + block = blocks.RevBlock( + n_res=3, + filters=32, + strides=(2, 2), + input_shape=input_shape, + data_format="channels_last") + dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) + self.assertEqual(dy.shape, x.shape) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + + +class _ResidualTest(tf.test.TestCase): + + def test_call(self): + """Test `call` function. + + Varying downsampling and data format options. 
+ """ + + _validate_block_call_channels_first(blocks._Residual, self) + _validate_block_call_channels_last(blocks._Residual, self) + + def test_backward_channels_first(self): + """Test `backward` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (16, 224, 224) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + residual = blocks._Residual( + filters=16, strides=(1, 1), input_shape=input_shape) + y_tr, y_ev = residual(x, training=True), residual(x, training=False) + x_ = residual.backward(y_tr, training=True) + # The numerical loss is alarming; reconstructed inputs could differ from + # the original inputs often by more than 1e-3 + self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + x_ = residual.backward(y_ev, training=False) + self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + + def test_backward_channels_last(self): + """Test `backward` function with `channels_last` data format.""" + with tf.device("/cpu:0"): # NHWC format + input_shape = (224, 224, 16) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + residual = blocks._Residual( + filters=16, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = residual(x, training=True), residual(x, training=False) + x_ = residual.backward(y_tr, training=True) + # Egregious numerical error + self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + x_ = residual.backward(y_ev, training=False) + self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + + def test_backward_grads_and_vars_channels_first(self): + """Test `backward_grads` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (16, 224, 224) + data_shape = (16,) + input_shape + x = 
tf.random_normal(shape=data_shape) + dy = tf.random_normal(shape=data_shape) + residual = blocks._Residual( + filters=16, strides=(1, 1), input_shape=input_shape) + dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars( + x, dy=dy, training=True) + dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars( + x, dy=dy, training=False) + self.assertNotAllClose(dx_tr, dx_ev) + self.assertTrue(isinstance(grads_tr, list)) + self.assertTrue(isinstance(grads_ev, list)) + self.assertTrue(isinstance(vars_tr, list)) + self.assertTrue(isinstance(vars_ev, list)) + for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev, + vars_ev): + if grad_tr is not None: # Batch norm moving mean, var gives None grad + self.assertEqual(grad_tr.shape, grad_ev.shape) + self.assertEqual(var_tr.shape, var_ev.shape) + self.assertEqual(grad_tr.shape, var_tr.shape) + + def test_backward_grads_and_vars_channels_last(self): + """Test `backward_grads` function with `channels_last` data format.""" + with tf.device("/cpu:0"): # NHWC format + input_shape = (224, 224, 16) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + dy = tf.random_normal(shape=data_shape) + residual = blocks._Residual( + filters=16, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars( + x, dy=dy, training=True) + dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars( + x, dy=dy, training=False) + self.assertNotAllClose(dx_tr, dx_ev) + self.assertTrue(isinstance(grads_tr, list)) + self.assertTrue(isinstance(grads_ev, list)) + self.assertTrue(isinstance(vars_tr, list)) + self.assertTrue(isinstance(vars_ev, list)) + for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev, + vars_ev): + if grad_tr is not None: # Batch norm moving mean, var gives None grad + self.assertEqual(grad_tr.shape, grad_ev.shape) + self.assertEqual(var_tr.shape, var_ev.shape) + 
self.assertEqual(grad_tr.shape, var_tr.shape) + + +class _ResidualInnerTest(tf.test.TestCase): + + def test_call(self): + """Test `call` function.""" + + _validate_block_call_channels_first(blocks._ResidualInner, self) + _validate_block_call_channels_last(blocks._ResidualInner, self) + + +class _BottleneckResidualInner(tf.test.TestCase): + + def test_call(self): + """Test `call` function.""" + + _validate_block_call_channels_first(blocks._BottleneckResidualInner, self) + _validate_block_call_channels_last(blocks._BottleneckResidualInner, self) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py new file mode 100644 index 0000000000000000000000000000000000000000..495a78d550a48fa56d6cfa276e47c9ff846edff3 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/config.py @@ -0,0 +1,117 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Configuration in format of tf.contrib.training.HParams. +Supports CIFAR-10, CIFAR-100, and ImageNet datasets. 
+
+Reference [The Reversible Residual Network: Backpropagation
+Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def get_hparams_cifar_38():
+  """Hyperparameters of RevNet-38 for CIFAR-10/CIFAR-100.
+
+  Returns:
+    tf.contrib.training.HParams with model and training settings; the image
+    layout is `channels_first` if a GPU is available, else `channels_last`
+  """
+
+  hparams = tf.contrib.training.HParams()
+
+  # Model architecture
+  architecture = (
+      ("init_filters", 32),
+      ("init_kernel", 3),
+      ("init_stride", 1),
+      ("n_classes", 10),
+      ("n_rev_blocks", 3),
+      ("n_res", [3, 3, 3]),
+      ("filters", [32, 64, 112]),
+      ("strides", [1, 2, 2]),
+      ("batch_size", 10),
+      ("bottleneck", False),
+      ("fused", True),
+      ("init_max_pool", False),
+  )
+  for name, value in architecture:
+    hparams.add_hparam(name, value)
+
+  # Image layout follows the available hardware
+  on_gpu = tf.test.is_gpu_available()
+  hparams.add_hparam("input_shape", (3, 32, 32) if on_gpu else (32, 32, 3))
+  hparams.add_hparam("data_format",
+                     "channels_first" if on_gpu else "channels_last")
+
+  # Training details
+  training = (
+      ("weight_decay", 2e-4),
+      ("momentum", .9),
+      ("lr_decay_steps", [40000, 60000]),
+      ("lr_list", [1e-1, 1e-2, 1e-3]),
+      ("max_train_iter", 80000),
+      ("seed", 1234),
+      ("shuffle", True),
+      ("prefetch", True),
+      ("print_every", 50),
+      ("dtype", tf.float32),
+      ("eval_batch_size", 500),
+      ("div255", True),
+  )
+  for name, value in training:
+    hparams.add_hparam(name, value)
+  # For tf.data.Dataset
+  hparams.add_hparam("epochs", hparams.max_train_iter // hparams.batch_size)
+
+  return hparams
+
+
+def get_hparams_imagenet_56():
+  """Hyperparameters of RevNet-56 for ImageNet.
+
+  Returns:
+    tf.contrib.training.HParams with model and training settings; the image
+    layout is `channels_first` if a GPU is available, else `channels_last`.
+    Since `bottleneck` is enabled, `filters` holds 4x the listed base widths.
+  """
+
+  hparams = tf.contrib.training.HParams()
+
+  # Model architecture
+  architecture = (
+      ("init_filters", 128),
+      ("init_kernel", 7),
+      ("init_stride", 2),
+      ("n_classes", 1000),
+      ("n_rev_blocks", 4),
+      ("n_res", [2, 2, 2, 2]),
+      ("filters", [128, 256, 512, 832]),
+      ("strides", [1, 2, 2, 2]),
+      ("batch_size", 16),
+      ("bottleneck", True),
+      ("fused", True),
+      ("init_max_pool", True),
+  )
+  for name, value in architecture:
+    hparams.add_hparam(name, value)
+
+  # Image layout follows the available hardware
+  on_gpu = tf.test.is_gpu_available()
+  hparams.add_hparam("input_shape", (3, 224, 224) if on_gpu else (224, 224, 3))
+  hparams.add_hparam("data_format",
+                     "channels_first" if on_gpu else "channels_last")
+
+  # Training details
+  training = (
+      ("weight_decay", 1e-4),
+      ("momentum", .9),
+      ("lr_decay_steps", [160000, 320000, 480000]),
+      ("lr_list", [1e-1, 1e-2, 1e-3, 1e-4]),
+      ("max_train_iter", 600000),
+      ("seed", 1234),
+      ("shuffle", True),
+      ("prefetch", True),
+      ("print_every", 50),
+      ("dtype", tf.float32),
+      ("eval_batch_size", 500),
+      ("div255", True),
+  )
+  for name, value in training:
+    hparams.add_hparam(name, value)
+  # For tf.data.Dataset
+  hparams.add_hparam("epochs", hparams.max_train_iter // hparams.batch_size)
+
+  # Bottleneck variant widens every stage by a factor of 4
+  if hparams.bottleneck:
+    hparams.filters = [width * 4 for width in hparams.filters]
+
+  return hparams
diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops.py b/tensorflow/contrib/eager/python/examples/revnet/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ed5d363e6c8bffd817357c006abee7ac0d1dbba
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/ops.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reversible residual network compatible with eager execution.
+
+Customized basic operations.
+
+Reference [The Reversible Residual Network: Backpropagation
+Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def downsample(x, filters, strides, axis=1):
+  """Downsample feature map with avg pooling, if filter size doesn't match.
+
+  Average pooling is applied only when `strides[0] > 1`. Channels are
+  zero-padded only when the input has fewer than `filters` channels; an input
+  with more channels than `filters` is returned with its channels untouched.
+
+  Args:
+    x: rank 4 tensor, NCHW if `axis` is 1, NHWC if `axis` is 3
+    filters: integer, number of output channels to pad up to
+    strides: length 2 list/tuple strides for height and width
+    axis: integer specifying feature dimension according to data format;
+      must be 1 or 3
+  Returns:
+    tensor with the spatial dimensions pooled and the channel dimension padded
+    as described above
+  """
+
+  def pad_strides(strides, axis=1):
+    """Convert length 2 to length 4 strides.
+
+    Needed since `tf.layers.Conv2D` uses length 2 strides, whereas operations
+    such as `tf.nn.avg_pool` use length 4 strides.
+
+    Args:
+      strides: length 2 list/tuple strides for height and width
+      axis: integer specifying feature dimension according to data format
+    Returns:
+      length 4 strides padded with 1 on batch and channel dimension
+    """
+
+    assert len(strides) == 2
+
+    if axis == 1:
+      return [1, 1, strides[0], strides[1]]
+    return [1, strides[0], strides[1], 1]
+
+  assert len(x.shape) == 4 and (axis == 1 or axis == 3)
+
+  data_format = "NCHW" if axis == 1 else "NHWC"
+  strides_ = pad_strides(strides, axis=axis)
+
+  if strides[0] > 1:
+    # Pooling window equals the stride, so windows do not overlap
+    x = tf.nn.avg_pool(
+        x, strides_, strides_, padding="VALID", data_format=data_format)
+
+  in_filter = x.shape[axis]
+  out_filter = filters
+
+  if in_filter < out_filter:
+    # Split the missing channels between both sides; when the difference is
+    # odd the extra channel goes on the trailing side, so that the output has
+    # exactly `filters` channels instead of `filters - 1`.
+    pad_total = out_filter - in_filter
+    pad_before = pad_total // 2
+    pad_size = [pad_before, pad_total - pad_before]
+    if axis == 1:
+      x = tf.pad(x, [[0, 0], pad_size, [0, 0], [0, 0]])
+    else:
+      x = tf.pad(x, [[0, 0], [0, 0], [0, 0], pad_size])
+  # In case `tape.gradient(x, [x])` produces a list of `None`
+  return x + 0.
diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops_test.py b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bc2641faf5a5d26262de683e52e36b1f42b3a7b
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for basic ops used in eager mode RevNet."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import ops
+tfe = tf.contrib.eager
+
+
+class OpsTest(tf.test.TestCase):
+
+  def test_downsample(self):
+    """Test `downsample` output shapes and gradient flow.
+
+    NHWC inputs are always checked; NCHW inputs only when a GPU is available.
+    """
+
+    batch_size = 100
+    # NHWC format
+    x = tf.random_normal(shape=[batch_size, 32, 32, 3])
+    # HW doesn't change but number of features increased
+    y = ops.downsample(x, filters=5, strides=(1, 1), axis=3)
+    self.assertEqual(y.shape, [batch_size, 32, 32, 5])
+    # Feature map doesn't change but HW reduced
+    y = ops.downsample(x, filters=3, strides=(2, 2), axis=3)
+    self.assertEqual(y.shape, [batch_size, 16, 16, 3])
+    # Number of feature increased and HW reduced
+    y = ops.downsample(x, filters=5, strides=(2, 2), axis=3)
+    self.assertEqual(y.shape, [batch_size, 16, 16, 5])
+
+    # Test gradient flow; `axis` and `dy` must follow the NHWC layout of `x`
+    x = tf.random_normal(shape=[batch_size, 32, 32, 3])
+    with tfe.GradientTape() as tape:
+      tape.watch(x)
+      y = ops.downsample(x, filters=3, strides=(1, 1), axis=3)
+    self.assertEqual(y.shape, x.shape)
+    dy = tf.random_normal(shape=[batch_size, 32, 32, 3])
+    grad, = tape.gradient(y, [x], output_gradients=[dy])
+    self.assertEqual(grad.shape, x.shape)
+
+    # Default NCHW format
+    if tf.test.is_gpu_available():
+      x = tf.random_normal(shape=[batch_size, 3, 32, 32])
+      # HW doesn't change but number of features increased
+      y = ops.downsample(x, filters=5, strides=(1, 1))
+      self.assertEqual(y.shape, [batch_size, 5, 32, 32])
+      # Feature map doesn't change but HW reduced
+      y = ops.downsample(x, filters=3, strides=(2, 2))
+      self.assertEqual(y.shape, [batch_size, 3, 16, 16])
+      # Number of features increased and HW reduced
+      y = ops.downsample(x, filters=5, strides=(2, 2))
+      self.assertEqual(y.shape, [batch_size, 5, 16, 16])
+
+      # Test gradient flow
+      x = tf.random_normal(shape=[batch_size, 3, 32, 32])
+      with tfe.GradientTape() as tape:
+        tape.watch(x)
+        y = ops.downsample(x, filters=3, strides=(1, 1))
+      self.assertEqual(y.shape, x.shape)
+      dy = tf.random_normal(shape=[batch_size, 3, 32, 32])
+      grad, = tape.gradient(y, [x], output_gradients=[dy])
+      self.assertEqual(grad.shape, x.shape)
+
+
+if __name__ == '__main__':
+  tf.enable_eager_execution()
+  tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa3f7efe1b6d8a44ce1bef065f24fa5c35cd404a
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -0,0 +1,263 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reversible residual network compatible with eager execution.
+
+Code for main model.
+
+Reference [The Reversible Residual Network: Backpropagation
+Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import operator
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import blocks
+
+
+# Global Conventions:
+# 1) Default data format is NCHW, targeting GPU
+# 2) Each block has attribute axis, inferred from data_format
+# 3) Default training option to True for batch normalization
+class RevNet(tf.keras.Model):
+  """RevNet that depends on all the blocks.
+
+  The network is an initial block, a list of `blocks.RevBlock` and a final
+  classification block. Gradients are computed block by block in
+  `compute_gradients` rather than with a single tape over the whole model.
+  """
+
+  def __init__(self, config):
+    """Initialize RevNet with building blocks.
+
+    Args:
+      config: tf.contrib.training.HParams object; specifies hyperparameters
+    """
+    super(RevNet, self).__init__()
+    self.axis = 1 if config.data_format == "channels_first" else 3
+    self.config = config
+
+    self._init_block = self._construct_init_block()
+    self._block_list = self._construct_intermediate_blocks()
+    self._final_block = self._construct_final_block()
+
+  def _construct_init_block(self):
+    """Build conv -> batch norm -> ReLU, plus max pooling if configured."""
+    init_block = tf.keras.Sequential(
+        [
+            tf.keras.layers.Conv2D(
+                filters=self.config.init_filters,
+                kernel_size=self.config.init_kernel,
+                strides=(self.config.init_stride, self.config.init_stride),
+                data_format=self.config.data_format,
+                use_bias=False,
+                padding="SAME",
+                input_shape=self.config.input_shape),
+            tf.keras.layers.BatchNormalization(
+                axis=self.axis, fused=self.config.fused),
+            tf.keras.layers.LeakyReLU(alpha=0.)
+        ],
+        name="init")
+    if self.config.init_max_pool:
+      init_block.add(
+          tf.keras.layers.MaxPooling2D(
+              pool_size=(3, 3),
+              strides=(2, 2),
+              padding="SAME",
+              data_format=self.config.data_format))
+    return init_block
+
+  def _construct_final_block(self):
+    """Build batch norm -> ReLU -> global average pooling -> dense logits.
+
+    Raises:
+      ValueError: if `data_format` is neither `channels_first` nor
+        `channels_last`
+    """
+    f = self.config.filters[-1]  # Number of filters
+    r = functools.reduce(operator.mul, self.config.strides, 1)  # Reduce ratio
+    r *= self.config.init_stride
+    if self.config.init_max_pool:
+      r *= 2  # The optional max pooling in the init block has stride 2
+
+    if self.config.data_format == "channels_first":
+      w, h = self.config.input_shape[1], self.config.input_shape[2]
+      input_shape = (f, w // r, h // r)
+    elif self.config.data_format == "channels_last":
+      w, h = self.config.input_shape[0], self.config.input_shape[1]
+      input_shape = (w // r, h // r, f)
+    else:
+      raise ValueError("Data format should be either `channels_first`"
+                       " or `channels_last`")
+
+    final_block = tf.keras.Sequential(
+        [
+            tf.keras.layers.BatchNormalization(
+                axis=self.axis,
+                input_shape=input_shape,
+                fused=self.config.fused),
+            tf.keras.layers.LeakyReLU(alpha=0.),  # Vanilla ReLU
+            tf.keras.layers.GlobalAveragePooling2D(
+                data_format=self.config.data_format),
+            tf.keras.layers.Dense(self.config.n_classes)
+        ],
+        name="final")
+    return final_block
+
+  def _construct_intermediate_blocks(self):
+    """Build the list of reversible blocks described by the config.
+
+    Returns:
+      tf.contrib.checkpoint.List of `blocks.RevBlock`, one per entry of
+      `n_res` / `filters` / `strides`
+
+    Raises:
+      ValueError: if any entry of `filters` is odd
+    """
+    # Precompute input shape after initial block
+    stride = self.config.init_stride
+    if self.config.init_max_pool:
+      stride *= 2
+    if self.config.data_format == "channels_first":
+      w, h = self.config.input_shape[1], self.config.input_shape[2]
+      input_shape = (self.config.init_filters, w // stride, h // stride)
+    else:
+      w, h = self.config.input_shape[0], self.config.input_shape[1]
+      input_shape = (w // stride, h // stride, self.config.init_filters)
+
+    # Aggregate intermediate blocks
+    block_list = tf.contrib.checkpoint.List()
+    for i in range(self.config.n_rev_blocks):
+      # RevBlock configurations
+      n_res = self.config.n_res[i]
+      filters = self.config.filters[i]
+      if filters % 2 != 0:
+        raise ValueError("Number of output filters must be even to ensure "
+                         "correct partitioning of channels")
+      stride = self.config.strides[i]
+      strides = (self.config.strides[i], self.config.strides[i])
+
+      # Add block
+      rev_block = blocks.RevBlock(
+          n_res,
+          filters,
+          strides,
+          input_shape,
+          batch_norm_first=(i != 0),  # Only skip on first block
+          data_format=self.config.data_format,
+          bottleneck=self.config.bottleneck,
+          fused=self.config.fused)
+      block_list.append(rev_block)
+
+      # Precompute input shape for the next block
+      if self.config.data_format == "channels_first":
+        w, h = input_shape[1], input_shape[2]
+        input_shape = (filters, w // stride, h // stride)
+      else:
+        w, h = input_shape[0], input_shape[1]
+        input_shape = (w // stride, h // stride, filters)
+
+    return block_list
+
+  def call(self, inputs, training=True):
+    """Forward pass.
+
+    Args:
+      inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
+      training: for batch normalization; also decides whether hidden states
+        are recorded
+
+    Returns:
+      A tuple `(logits, saved_hidden)`. `saved_hidden` is the list made of
+      `inputs`, the init block output and each intermediate block output when
+      `training` is True, and `None` otherwise.
+    """
+
+    # Only store hidden states during training
+    if training:
+      saved_hidden = [inputs]
+
+    h = self._init_block(inputs, training=training)
+    if training:
+      saved_hidden.append(h)
+
+    for block in self._block_list:
+      h = block(h, training=training)
+      if training:
+        saved_hidden.append(h)
+
+    logits = self._final_block(h, training=training)
+
+    return (logits, saved_hidden) if training else (logits, None)
+
+  def compute_loss(self, logits, labels):
+    """Compute cross entropy loss.
+
+    Args:
+      logits: unscaled class scores produced by the final block
+      labels: integer class labels, as consumed by
+        `tf.nn.sparse_softmax_cross_entropy_with_logits`
+
+    Returns:
+      Scalar tensor, the cross entropy averaged over all entries
+    """
+
+    cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        logits=logits, labels=labels)
+
+    return tf.reduce_mean(cross_ent)
+
+  def compute_gradients(self, inputs, labels, training=True):
+    """Manually computes gradients.
+
+    Runs a forward pass, then backpropagates through the final block, the
+    intermediate blocks in reverse order, and finally the init block.
+
+    Args:
+      inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
+      labels: Integer class labels for classification (see `compute_loss`)
+      training: for batch normalization; must be True, because `call` returns
+        no hidden states otherwise and this method needs them
+
+    Returns:
+      A tuple `(grads_all, vars_all)` of two parallel lists: the gradients and
+      the variables they belong to, ready to be zipped for an optimizer
+    """
+
+    # Forward pass record hidden states before downsampling
+    _, saved_hidden = self.call(inputs, training=training)
+
+    grads_all = []
+    vars_all = []
+
+    # Manually backprop through last block
+    x = saved_hidden[-1]
+    with tf.GradientTape() as tape:
+      tape.watch(x)
+      logits = self._final_block(x, training=training)
+      cost = self.compute_loss(logits, labels)
+
+    grads_combined = tape.gradient(cost, [x] + self._final_block.variables)
+    dy, grads_ = grads_combined[0], grads_combined[1:]
+    grads_all += grads_
+    vars_all += self._final_block.variables
+
+    # Manually backprop through intermediate blocks
+    for block in reversed(self._block_list):
+      y = saved_hidden.pop()
+      x = saved_hidden[-1]
+      dy, grads, vars_ = block.backward_grads_and_vars(
+          x, y, dy, training=training)
+      grads_all += grads
+      vars_all += vars_
+
+    # Manually backprop through first block
+    saved_hidden.pop()
+    x = saved_hidden.pop()
+    assert not saved_hidden  # Cleared after backprop
+
+    with tf.GradientTape() as tape:
+      y = self._init_block(x, training=training)  # Recomputing
+
+    grads_all += tape.gradient(
+        y, self._init_block.variables, output_gradients=[dy])
+    vars_all += self._init_block.variables
+
+    return grads_all, vars_all
+
+  def train_step(self,
+                 inputs,
+                 labels,
+                 optimizer,
+                 global_step=None,
+                 report=False):
+    """Train for one iteration.
+
+    Args:
+      inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
+      labels: Integer class labels for classification
+      optimizer: object whose `apply_gradients` consumes (grad, var) pairs
+      global_step: forwarded to `optimizer.apply_gradients`
+      report: whether to run an extra forward pass and return the loss
+
+    Returns:
+      The loss computed after the update if `report` is True, else `None`
+    """
+
+    grads_all, vars_all = self.compute_gradients(inputs, labels, training=True)
+    optimizer.apply_gradients(zip(grads_all, vars_all), global_step=global_step)
+
+    if not report:
+      return None
+
+    # Extra forward pass, run after the update, only to report the loss
+    logits, _ = self.call(inputs, training=True)
+    return self.compute_loss(logits, labels)
+
+  def eval_step(self, inputs, labels):
+    """Evaluate.
+
+    Args:
+      inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
+      labels: Integer class labels; compared against int32 predictions
+
+    Returns:
+      Scalar tensor, fraction of predictions equal to `labels`
+    """
+
+    logits, _ = self.call(inputs, training=False)
+    preds = tf.cast(tf.argmax(logits, axis=1), tf.int32)
+    corrects = tf.cast(tf.equal(preds, labels), tf.float32)
+    accuracy = tf.reduce_mean(corrects)
+
+    return accuracy
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..68502ceac2360e2b9ea965743d507439a09c3e59
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -0,0 +1,277 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests and benchmarks for the eager mode RevNet main model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gc
+import time
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import config as config_
+from tensorflow.contrib.eager.python.examples.revnet import revnet
+from tensorflow.python.client import device_lib
+tfe = tf.contrib.eager
+
+
+class RevnetTest(tf.test.TestCase):
+
+  def setUp(self):
+    super(RevnetTest, self).setUp()
+    config = config_.get_hparams_imagenet_56()
+    shape = (config.batch_size,) + config.input_shape
+    self.model = revnet.RevNet(config=config)
+    self.x = tf.random_normal(shape=shape)
+    # Random integer class labels in [0, n_classes)
+    self.t = tf.random_uniform(
+        shape=[config.batch_size],
+        minval=0,
+        maxval=config.n_classes,
+        dtype=tf.int32)
+    self.config = config
+
+  def tearDown(self):
+    del self.model
+    del self.x
+    del self.t
+    del self.config
+    super(RevnetTest, self).tearDown()
+
+  def test_call(self):
+    """Test `call` function."""
+
+    y, _ = self.model(self.x, training=False)
+    self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes])
+
+  def test_compute_gradients(self):
+    """Test `compute_gradients` function."""
+
+    grads, vars_ = self.model.compute_gradients(inputs=self.x, labels=self.t)
+    self.assertTrue(isinstance(grads, list))
+    self.assertTrue(isinstance(vars_, list))
+    self.assertEqual(len(grads), len(vars_))
+    for grad, var in zip(grads, vars_):
+      if grad is not None:
+        self.assertEqual(grad.shape, var.shape)
+
+  def test_train_step(self):
+    """Test `train_step` function."""
+
+    logits, _ = self.model(self.x, training=True)
+    loss = self.model.compute_loss(logits=logits, labels=self.t)
+    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
+
+    # Loss should be decreasing after each optimization step
+    for _ in range(3):
+      loss_ = self.model.train_step(self.x, self.t, optimizer, report=True)
+      self.assertTrue(loss_.numpy() <= loss.numpy())
+      loss = loss_
+
+  def test_call_defun(self):
+    """Test `call` function with tfe.defun apply."""
+
+    y, _ = tfe.defun(self.model.call)(self.x, training=False)
+    self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes])
+
+  def test_train_step_defun(self):
+    """Test `train_step` function with `call` wrapped in tfe.defun."""
+    self.model.call = tfe.defun(self.model.call)
+    logits, _ = self.model(self.x, training=True)
+    loss = self.model.compute_loss(logits=logits, labels=self.t)
+    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
+
+    for _ in range(3):
+      loss_ = self.model.train_step(self.x, self.t, optimizer, report=True)
+      self.assertTrue(loss_.numpy() <= loss.numpy())
+      loss = loss_
+
+    # Initialize new model, so that other tests are not affected
+    self.model = revnet.RevNet(config=self.config)
+
+
+# Benchmark related
+def device_and_data_format():
+  """Returns the (device, data format) pair matching the available hardware."""
+  if tf.test.is_gpu_available():
+    return ("/gpu:0", "channels_first")
+  return ("/cpu:0", "channels_last")
+
+
+def random_batch(batch_size, config):
+  """Returns random images of `config.input_shape` and random int32 labels."""
+  shape = (batch_size,) + config.input_shape
+  images = tf.random_uniform(shape)
+  labels = tf.random_uniform(
+      [batch_size], minval=0, maxval=config.n_classes, dtype=tf.int32)
+
+  return images, labels
+
+
+class MockIterator(object):
+  """Iterator stand-in that returns the same tensors on every `next` call."""
+
+  def __init__(self, tensors):
+    self._tensors = [tf.identity(x) for x in tensors]
+
+  def next(self):
+    return self._tensors
+
+
+class RevnetBenchmark(tf.test.Benchmark):
+  """Eager and graph benchmarks for RevNet."""
+
+  def _train_batch_sizes(self):
+    """Shamelessly copied from `resnet50_test.py`.
+
+    Note: This is targeted towards ImageNet. CIFAR-10 should allow more
+    aggressive batch sizes.
+
+    Returns:
+      A tuple of possible batch sizes
+    """
+    for device in device_lib.list_local_devices():
+      if tf.DeviceSpec.from_string(device.name).device_type == "GPU":
+        if "K20" in device.physical_device_desc:
+          return (16,)
+        if "P100" in device.physical_device_desc:
+          return (16, 32, 64)
+      if tf.DeviceSpec.from_string(device.name).device_type == "TPU":
+        return (32,)
+    return (16, 32)
+
+  def _force_device_sync(self):
+    """Shamelessly copied from `resnet50_test.py`."""
+    tf.constant(1.).cpu()
+
+  def _report(self, label, start, num_iters, device, batch_size, data_format):
+    """Reports the average wall time per iteration since `start`."""
+    avg_time = (time.time() - start) / num_iters
+    dev = tf.DeviceSpec.from_string(device).device_type.lower()
+    name = "%s_%s_batch_%d_%s" % (label, dev, batch_size, data_format)
+    extras = {"examples_per_sec": batch_size / avg_time}
+    self.report_benchmark(
+        iters=num_iters, wall_time=avg_time, name=name, extras=extras)
+
+  def _benchmark_eager_apply(self,
+                             label,
+                             device_and_format,
+                             defun=False,
+                             execution_mode=None,
+                             compiled=False):
+    """Benchmarks the forward pass with `training=False`."""
+    config = config_.get_hparams_imagenet_56()
+    with tfe.execution_mode(execution_mode):
+      device, data_format = device_and_format
+      model = revnet.RevNet(config=config)
+      if defun:
+        model.call = tfe.defun(model.call, compiled=compiled)
+      batch_size = 64
+      num_burn = 5
+      num_iters = 10
+      with tf.device(device):
+        images, _ = random_batch(batch_size, config)
+        # Warm-up iterations, excluded from the timing
+        for _ in range(num_burn):
+          model(images, training=False)
+        if execution_mode:
+          tfe.async_wait()
+        gc.collect()
+        start = time.time()
+        for _ in range(num_iters):
+          model(images, training=False)
+        if execution_mode:
+          tfe.async_wait()
+        self._report(label, start, num_iters, device, batch_size, data_format)
+
+  def benchmark_eager_apply_sync(self):
+    self._benchmark_eager_apply(
+        "eager_apply_sync", device_and_data_format(), defun=False)
+
+  def benchmark_eager_apply_async(self):
+    self._benchmark_eager_apply(
+        "eager_apply_async",
+        device_and_data_format(),
+        defun=False,
+        execution_mode=tfe.ASYNC)
+
+  def benchmark_eager_call_defun(self):
+    self._benchmark_eager_apply(
+        "eager_apply_with_defun", device_and_data_format(), defun=True)
+
+  def _benchmark_eager_train(self,
+                             label,
+                             make_iterator,
+                             device_and_format,
+                             defun=False,
+                             execution_mode=None,
+                             compiled=False):
+    """Benchmarks `train_step` for every size in `_train_batch_sizes`."""
+    config = config_.get_hparams_imagenet_56()
+    with tfe.execution_mode(execution_mode):
+      device, data_format = device_and_format
+      for batch_size in self._train_batch_sizes():
+        (images, labels) = random_batch(batch_size, config)
+        model = revnet.RevNet(config=config)
+        optimizer = tf.train.GradientDescentOptimizer(0.1)
+        if defun:
+          # Forward `compiled`, as `_benchmark_eager_apply` does
+          model.call = tfe.defun(model.call, compiled=compiled)
+
+        num_burn = 3
+        num_iters = 10
+        with tf.device(device):
+          iterator = make_iterator((images, labels))
+          # Warm-up iterations, excluded from the timing
+          for _ in range(num_burn):
+            (images, labels) = iterator.next()
+            model.train_step(images, labels, optimizer)
+          if execution_mode:
+            tfe.async_wait()
+          self._force_device_sync()
+          gc.collect()
+
+          start = time.time()
+          for _ in range(num_iters):
+            (images, labels) = iterator.next()
+            model.train_step(images, labels, optimizer)
+          if execution_mode:
+            tfe.async_wait()
+          self._force_device_sync()
+          self._report(label, start, num_iters, device, batch_size, data_format)
+
+  def benchmark_eager_train_sync(self):
+    self._benchmark_eager_train(
+        "eager_train_sync", MockIterator, device_and_data_format(), defun=False)
+
+  def benchmark_eager_train_async(self):
+    self._benchmark_eager_train(
+        "eager_train_async",
+        MockIterator,
+        device_and_data_format(),
+        defun=False,
+        execution_mode=tfe.ASYNC)
+
+  def benchmark_eager_train_defun(self):
+    # Must enable defun; otherwise this repeats `benchmark_eager_train_sync`
+    self._benchmark_eager_train(
+        "eager_train_with_defun",
+        MockIterator,
+        device_and_data_format(),
+        defun=True)
+
+  def benchmark_eager_train_datasets_with_defun(self):
+
+    def make_iterator(tensors):
+      with tf.device("/device:CPU:0"):
+        ds = tf.data.Dataset.from_tensors(tensors).repeat()
+      return tfe.Iterator(ds)
+
+    self._benchmark_eager_train(
+        "eager_train_dataset_with_defun",
+        make_iterator,
+        device_and_data_format(),
+        defun=True)
+
+
+if __name__ == "__main__":
+  tf.enable_eager_execution()
+  tf.test.main()
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 1937ffb583bc727df76470d072b35fb3c9acaa88..30d297a5fb2dd2f844093d790d051a79105984dd 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -117,7 +117,7 @@ py_library( py_test( name = "dnn_test", - size = "small", + size = "medium", srcs = ["python/estimator/dnn_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index 7ff25b95c079c7e06d29e874bcaa0d2c13e7167e..f1c60a912c8b1daa7db34f46e92bcc36ab300716 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -53,6 +53,13 @@ class DNNEstimator(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator with warm-starting from a previous checkpoint. + estimator = DNNEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], + hidden_units=[1024, 512, 256], + warm_start_from="/path/to/checkpoint/dir") + # Input builders def input_fn_train: # returns x, y pass @@ -92,7 +99,8 @@ class DNNEstimator(estimator.Estimator): activation_fn=nn.relu, dropout=None, input_layer_partitioner=None, - config=None): + config=None, + warm_start_from=None): """Initializes a `DNNEstimator` instance. Args: @@ -116,6 +124,11 @@ class DNNEstimator(estimator.Estimator): input_layer_partitioner: Optional. Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings.
+ warm_start_from: A string filepath to a checkpoint to warm-start from, or + a `WarmStartSettings` object to fully configure warm-starting. If the + string filepath is provided instead of a `WarmStartSettings`, then all + weights are warm-started, and it is assumed that vocabularies and Tensor + names are unchanged. """ def _model_fn(features, labels, mode, config): return dnn_lib._dnn_model_fn( # pylint: disable=protected-access @@ -131,4 +144,5 @@ class DNNEstimator(estimator.Estimator): input_layer_partitioner=input_layer_partitioner, config=config) super(DNNEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) + model_fn=_model_fn, model_dir=model_dir, config=config, + warm_start_from=warm_start_from) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py index 75e3107670d658e55ce23d983e47311f1c180104..050b0428bf7b685229e12561cfb0682d931299d2 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py @@ -38,7 +38,7 @@ from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache -def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): +def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg """Returns a DNNEstimator that uses regression_head.""" return dnn.DNNEstimator( head=head_lib.regression_head( @@ -48,6 +48,12 @@ def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): *args, **kwargs) +def _dnn_estimator_classifier_fn(n_classes=3, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg + """Returns a DNNEstimator that uses multi_class_head.""" + return dnn.DNNEstimator(head=head_lib.multi_class_head(n_classes=n_classes), + *args, **kwargs) + + class DNNEstimatorEvaluateTest( 
dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase): @@ -75,6 +81,15 @@ class DNNEstimatorTrainTest( self, _dnn_estimator_fn) +class DNNEstimatorWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNWarmStartingTest.__init__( + self, _dnn_estimator_classifier_fn, _dnn_estimator_fn) + + class DNNEstimatorIntegrationTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index b798769d2cfde69e9e0b8d65882a07d038cbb994..9594e5132fd20dadea118fd1dd6768feb7fd7fff 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -529,6 +529,7 @@ def multi_label_head(n_classes, applications, the shape is `[batch_size, n_classes]`. Labels can be: + * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` * An integer `SparseTensor` of class indices. The `dense_shape` must be `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. 
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 89b5f4c4137f6c42417f539a578fd8b11f8b235d..45d7b740462ca21139e2e93e34b43668f1e08a94 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -110,7 +110,7 @@ class SequenceInputLayerTest(test.TestCase): expected_sequence_length, sequence_length.eval(session=sess)) def test_embedding_column_with_non_sequence_categorical(self): - """Tests that error is raised for non-sequence categorical column.""" + """Tests that error is raised for non-sequence embedding column.""" vocabulary_size = 3 sparse_input = sparse_tensor.SparseTensorValue( # example 0, ids [2] @@ -132,6 +132,107 @@ class SequenceInputLayerTest(test.TestCase): features={'aaa': sparse_input}, feature_columns=[embedding_column_a]) + def test_shared_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) 
# id 2 + ) + + def _get_initializer(embedding_dimension, embedding_values): + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 3., 4.], [0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 5., 6.], [3., 4., 1., 2.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + # Test that columns are reordered alphabetically. + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension, + initializer=_get_initializer(embedding_dimension, embedding_values)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=shared_embedding_columns) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_shared_embedding_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence shared embedding column.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + 
indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_shared_embedding\. categorical_column must ' + r'be of type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b + }, + feature_columns=shared_embedding_columns) + def test_indicator_column(self): vocabulary_size_a = 3 sparse_input_a = sparse_tensor.SparseTensorValue( @@ -578,6 +679,182 @@ class SequenceEmbeddingColumnTest(test.TestCase): expected_sequence_length, sequence_length.eval(session=sess)) +class SequenceSharedEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) 
# id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [0, 2] + # example 2, ids [0] + # example 3, ids [] + indices=((0, 0), (1, 0), (1, 1), (2, 0)), + values=(1, 0, 2, 0), + dense_shape=(4, 2)) + + expected_lookups_a = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + expected_lookups_b = [ + # example 0, ids [1] + [[3., 5.], [0., 0.]], + # example 1, ids [0, 2] + [[1., 2.], [7., 11.]], + # example 2, ids [0] + [[1., 2.], [0., 0.]], + # example 3, ids [] + [[0., 0.], [0., 0.]], + ] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[0] + embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[0] + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + with 
monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual( + expected_lookups_a, embedding_lookup_a.eval(session=sess)) + self.assertAllEqual( + expected_lookups_b, embedding_lookup_b.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length_a = [1, 2] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [0, 2] + # example 1, ids [1] + indices=((0, 0), (0, 1), (1, 0)), + values=(0, 2, 1), + dense_shape=(2, 2)) + expected_sequence_length_b = [2, 1] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + sequence_length_a = sess.run(sequence_length_a) + self.assertAllEqual(expected_sequence_length_a, sequence_length_a) + self.assertEqual(np.int64, sequence_length_a.dtype) + sequence_length_b = sess.run(sequence_length_b) + self.assertAllEqual(expected_sequence_length_b, sequence_length_b) + self.assertEqual(np.int64, sequence_length_b.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, 
ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length_a = [0, 1, 2, 0, 1, 0] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [] + # example 2, ids [] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [0, 1] + indices=((0, 0), (4, 0), (5, 0), (5, 1)), + values=(2, 1, 0, 1), + dense_shape=(6, 2)) + expected_sequence_length_b = [1, 0, 0, 0, 1, 2] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length_a, sequence_length_a.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length_b, sequence_length_b.eval(session=sess)) + + class SequenceIndicatorColumnTest(test.TestCase): def test_get_sequence_dense_tensor(self): diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 40ae01bfcce1dde580e6a5f6d9c8ec1aa1abb83f..e8e318001972934c7d2154bc14744823a3ba09f9 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -712,7 +712,8 @@ class VariableDeviceChooser(object): num_tasks=0, job_name='ps', device_type='CPU', - device_index=0): + 
device_index=0, + replica=None): """Initialize VariableDeviceChooser. Usage: @@ -733,12 +734,15 @@ class VariableDeviceChooser(object): self._job_name = job_name self._device_type = device_type self._device_index = device_index + self._replica = replica self._num_tasks = num_tasks self._next_task_id = 0 def __call__(self, op): - device_spec = tf_device.DeviceSpec(device_type=self._device_type, - device_index=self._device_index) + device_spec = tf_device.DeviceSpec( + replica=self._replica, + device_type=self._device_type, + device_index=self._device_index) if self._num_tasks > 0: task_id = self._next_task_id self._next_task_id = (self._next_task_id + 1) % self._num_tasks diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 37ea6eb12aba7d25656f19cbbc86475c1228d916..7e0c7dbec1d9266b53a169fe83b88d1e3af77d04 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -506,6 +506,35 @@ class VariablesTest(test.TestCase): self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0') self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + def testVariableWithVariableDeviceChooserWithReplica(self): + + with ops.Graph().as_default(): + device_fn = variables_lib2.VariableDeviceChooser(replica=3, num_tasks=2) + with arg_scope([variables_lib2.variable], device=device_fn): + a = variables_lib2.variable('a', []) + b = variables_lib2.variable('b', []) + c = variables_lib2.variable('c', [], device='cpu:12') + d = variables_lib2.variable('d', []) + with ops.device('cpu:99'): + e_init = constant_op.constant(12) + e = variables_lib2.variable('e', initializer=e_init) + # The values below highlight how the VariableDeviceChooser puts initial + # values on the same device as the variable job. 
+ self.assertDeviceEqual(a.device, '/job:ps/replica:3/task:0/cpu:0') + self.assertEqual(a.initial_value.op.colocation_groups(), + a.op.colocation_groups()) + self.assertDeviceEqual(b.device, '/job:ps/replica:3/task:1/cpu:0') + self.assertEqual(b.initial_value.op.colocation_groups(), + b.op.colocation_groups()) + self.assertDeviceEqual(c.device, '/cpu:12') + self.assertEqual(c.initial_value.op.colocation_groups(), + c.op.colocation_groups()) + self.assertDeviceEqual(d.device, '/job:ps/replica:3/task:0/cpu:0') + self.assertEqual(d.initial_value.op.colocation_groups(), + d.op.colocation_groups()) + self.assertDeviceEqual(e.device, '/job:ps/replica:3/task:1/cpu:0') + self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + def testVariableGPUPlacement(self): with ops.Graph().as_default(): @@ -930,8 +959,8 @@ class AssignFromCheckpointTest(test.TestCase): return saver.save(sess, checkpoint_dir, global_step=global_step) def testLoadExistingVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'load_existing_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables')) init_value0 = 10.0 init_value1 = 20.0 @@ -944,8 +973,8 @@ class AssignFromCheckpointTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -960,8 +989,8 @@ class AssignFromCheckpointTest(test.TestCase): # Tests restoring PartitionedVariables and tests using a dictionary # of lists as the assign_from_checkpoint() var_list param. 
def testLoadPartitionedVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'load_partitioned_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_partitioned_variables')) init_value0 = np.array([[10.0, 11.0], [12.0, 13.0]]) init_value1 = np.array([20.0]) # Partitioned into 1 part, edge case. @@ -974,15 +1003,14 @@ class AssignFromCheckpointTest(test.TestCase): partitioner = partitioned_variables.variable_axis_size_partitioner(2) var0 = variables_lib2.variable( 'var0', shape=init_value0.shape, partitioner=partitioner) - var0full = variables_lib2.variable( - 'var0full', shape=init_value0.shape) + var0full = variables_lib2.variable('var0full', shape=init_value0.shape) var1 = variables_lib2.variable( 'var1', shape=init_value1.shape, partitioner=partitioner) # Convert var0 and var1 into a list of underlying variables. vars_to_restore = {'var0': list(var0) + [var0full], 'var1': list(var1)} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -992,16 +1020,18 @@ class AssignFromCheckpointTest(test.TestCase): # Request and test the variable values. PartitionedVariables can't # be evaled so we wrap them in an identity. 
- self.assertTrue(np.array_equal( - init_value0, array_ops.identity(var0).eval())) - self.assertTrue(np.array_equal( - init_value0, var0full.eval())) - self.assertTrue(np.array_equal( - init_value1, array_ops.identity(var1).eval())) + self.assertTrue( + np.array_equal(init_value0, + array_ops.identity(var0).eval())) + self.assertTrue(np.array_equal(init_value0, var0full.eval())) + self.assertTrue( + np.array_equal(init_value1, + array_ops.identity(var1).eval())) def testRaisesValueErrorIfAVariableIsntFound(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'raises_value_error_if_var_isnt_found')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'raises_value_error_if_var_isnt_found')) init_value0 = 10.0 init_value1 = 20.0 @@ -1019,8 +1049,9 @@ class AssignFromCheckpointTest(test.TestCase): variables_lib2.assign_from_checkpoint(model_path, vars_to_restore) def testInitFromCheckpointWithScopes(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'init_from_checkpoint_with_scopes')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'init_from_checkpoint_with_scopes')) init_value0 = np.asarray( [1.0, 3.0, 9.0], dtype=np.float32).reshape((1, 3, 1)) @@ -1038,8 +1069,8 @@ class AssignFromCheckpointTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=init_value1.shape) vars_to_restore = {'layer0/v0': var0, 'layer1/v1': var1} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. 
sess.run(variables_lib.global_variables_initializer()) @@ -1081,8 +1112,8 @@ class AssignFromCheckpointFnTest(test.TestCase): return saver.save(sess, checkpoint_dir, global_step=global_step) def testLoadExistingVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'load_existing_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1097,8 +1128,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1111,8 +1142,9 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'load_existing_vars_no_reshape')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'load_existing_vars_no_reshape')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1127,8 +1159,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. 
sess.run(variables_lib.global_variables_initializer()) @@ -1138,9 +1170,10 @@ class AssignFromCheckpointFnTest(test.TestCase): init_fn(sess) def testLoadExistingVariablesDifferentShapeAllowReshape(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), - 'load_existing_variables_different_shape_allow_reshape')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join( + self.get_temp_dir(), + 'load_existing_variables_different_shape_allow_reshape')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1169,8 +1202,8 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testNotFoundError(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'not_found_error')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'not_found_error')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1186,8 +1219,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var2 = variables_lib2.variable('my_var2', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1, 'v2': var2} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. 
sess.run(variables_lib.global_variables_initializer()) @@ -1197,8 +1230,8 @@ class AssignFromCheckpointFnTest(test.TestCase): init_fn(sess) def testMissingVariablesList(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'missing_variables_list')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'missing_variables_list')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1228,8 +1261,8 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testMissingVariablesDict(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'missing_variables_dict')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'missing_variables_dict')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1279,9 +1312,8 @@ class ZeroInitializerOpTest(test.TestCase): def testZeroInitializer(self): for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64): for use_init in (False, True): - self._testZeroInitializer( - [10, 20], array_ops.ones( - [10, 20], dtype=dtype), use_init) + self._testZeroInitializer([10, 20], array_ops.ones( + [10, 20], dtype=dtype), use_init) class ZeroVarInitializerOpTest(test.TestCase): diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 4d62ac65ff619f98a18387058fdc8a0eade0d8f8..a955e21b72e765f751318c7927f9644481fe7933 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import 
rewriter_config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -33,6 +35,13 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging +def NoMemoryOptimizationConfig(): + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.memory_optimization = ( + rewriter_config_pb2.RewriterConfig.OFF) + return config + + def GetShrunkInceptionShapes(shrink=10): """Iterator for smaller versions of convolution shapes in 2015 Inception. @@ -193,7 +202,8 @@ class FusedConv2DBiasActivationTest(test.TestCase): # This is to guarantee that there is always negative values after # bias add so that we can test whether relu works correctly. x3 = bias - with self.test_session(use_gpu=True): + # TODO(b/79323979): re-enable memory optimization after this bug is fixed. + with self.test_session(use_gpu=True, config=NoMemoryOptimizationConfig()): t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) fused_t2 = t2 @@ -241,7 +251,9 @@ class FusedConv2DBiasActivationTest(test.TestCase): x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32) def _SetupVal(data_format, use_gpu): - with self.test_session(use_gpu=use_gpu): + # TODO(b/79323979): re-enable memory optimization after this bug is fixed. + with self.test_session( + use_gpu=use_gpu, config=NoMemoryOptimizationConfig()): t1 = constant_op.constant(x1, shape=tensor_in_sizes) t2 = constant_op.constant(x2, shape=filter_in_sizes) t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]]) @@ -865,7 +877,9 @@ class FusedConvInt8Tests(test.TestCase): conv_input_scale, conv_input, kernel, padding_type, strides, side_input_scale, side_input, biases) - with self.test_session(use_gpu=True) as sess: + # TODO(b/79323979): re-enable memory optimization after this bug is fixed. 
+ with self.test_session( + use_gpu=True, config=NoMemoryOptimizationConfig()) as sess: actual_y, expected_y = sess.run([actual, expected]) tf_logging.info("actual_y = ", actual_y) tf_logging.info("expected_y = ", expected_y) diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index 1f9dd0decb84cf9b7b703f18c061d3c0c7a1cb25..9025c992a4467f521d6d8d514e6a5e92f5492947 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -57,7 +57,7 @@ Status GdrServer::Init() { new GdrWorker(env, remote_memory_manager_.get())); }; TF_RETURN_IF_ERROR( - GrpcServer::Init(nullptr, rendezvous_mgr_func, worker_func)); + GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func)); return remote_memory_manager_->Init(); } diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py index b4a99867ed46897f60be3f230838c3f576d5455e..61f78febfc07bb4e677259366a81c16b2b585244 100644 --- a/tensorflow/contrib/integrate/python/ops/odes.py +++ b/tensorflow/contrib/integrate/python/ops/odes.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops @@ -279,13 +278,27 @@ def _assert_increasing(t): return ops.control_dependencies([assert_increasing]) -def _check_input_types(t, y0): +def _check_input_types(y0, t, dt=None): if not (y0.dtype.is_floating or y0.dtype.is_complex): raise TypeError('`y0` must have a floating point or complex floating ' 'point dtype') if not t.dtype.is_floating: raise TypeError('`t` must have a floating point dtype') + if dt is not None and not dt.dtype.is_floating: + raise TypeError('`dt` must have a floating point dtype') + + +def 
_check_input_sizes(t, dt): + if len(t.get_shape().as_list()) > 1: + raise ValueError('t must be a 1D tensor') + + if len(dt.get_shape().as_list()) > 1: + raise ValueError('dt must be a 1D tensor') + + if t.get_shape()[0] != dt.get_shape()[0] + 1: + raise ValueError('t and dt have incompatible lengths, must be N and N-1') + def _dopri5(func, y0, @@ -510,7 +523,7 @@ def odeint(func, # avoiding the need to pack/unpack in user functions. y0 = ops.convert_to_tensor(y0, name='y0') t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') - _check_input_types(t, y0) + _check_input_types(y0, t) error_dtype = abs(y0).dtype rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol') @@ -530,24 +543,74 @@ def odeint(func, class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): """Base class for fixed-grid ODE integrators.""" - def integrate(self, evol_func, y0, time_grid): - time_delta_grid = time_grid[1:] - time_grid[:-1] - - scan_func = self._make_scan_func(evol_func) + def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals): + """Returns integrated values of differential equation on the `time_grid`. + + Numerically integrates differential equation defined via time derivative + evaluator `evol_func` using fixed time steps specified in dt_grid. + + Args: + evol_func: Callable, evaluates time derivative of y at a given time. + y0: N-D Tensor holds initial values of the solution. + time_grid: 1-D Tensor holding the time points at which the solution + will be recorded, must have a floating dtype. + dt_grid: 1-D Tensor holds fixed time steps to be used on time_grid + intervals. Must be a floating dtype and have one less element than that + of the time_grid. + steps_on_intervals: 1-D Tensor of integer dtype, must have the same size + as dt_grid. Specifies number of steps needed for every interval. Assumes + steps_on_intervals * dt_grid == time intervals.
+ + Returns: + (N+1)-D tensor, where the first dimension corresponds to different + time points. Contains the solved value of y for each desired time point in + `t`, with the initial value `y0` being the first element along the first + dimension. + """ - y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid), - y0) - return array_ops.concat([[y0], y_grid], axis=0) + iteration_func = self._make_iteration_func(evol_func, dt_grid) + integrate_interval = self._make_interval_integrator(iteration_func, + steps_on_intervals) - def _make_scan_func(self, evol_func): + num_times = array_ops.size(time_grid) + current_time = time_grid[0] + solution_array = tensor_array_ops.TensorArray(y0.dtype, num_times) + solution_array = solution_array.write(0, y0) - def scan_func(y, t_and_dt): - t, dt = t_and_dt + solution_array, _, _, _ = control_flow_ops.while_loop( + lambda _, __, ___, i: i < num_times, + integrate_interval, + (solution_array, y0, current_time, 1) + ) + solution_array = solution_array.stack() + solution_array.set_shape(time_grid.get_shape().concatenate(y0.get_shape())) + return solution_array + + def _make_iteration_func(self, evol_func, dt_grid): + """Returns a function that builds operations of a single time step.""" + + def iteration_func(y, t, dt_step, interval_step): + """Performs a single time step advance.""" + dt = dt_grid[interval_step - 1] dy = self._step_func(evol_func, t, dt, y) dy = math_ops.cast(dy, dtype=y.dtype) - return y + dy + return y + dy, t + dt, dt_step + 1, interval_step + + return iteration_func + + def _make_interval_integrator(self, iteration_func, interval_sizes): + """Returns a function that builds operations for interval integration.""" - return scan_func + def integrate_interval(solution_array, y, t, interval_num): + """Integrates y with fixed time step on interval `interval_num`.""" + y, t, _, _ = control_flow_ops.while_loop( + lambda _, __, j, interval_num: j < interval_sizes[interval_num - 1], + iteration_func, + (y, 
t, 0, interval_num) + ) + return solution_array.write(interval_num, y), y, t, interval_num + 1 + + return integrate_interval @abc.abstractmethod def _step_func(self, evol_func, t, dt, y): @@ -555,6 +618,7 @@ class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): class _MidpointFixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing midpoint scheme.""" def _step_func(self, evol_func, t, dt, y): dt_cast = math_ops.cast(dt, y.dtype) @@ -563,6 +627,7 @@ class _MidpointFixedGridIntegrator(_FixedGridIntegrator): class _RK4FixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing RK4 scheme.""" def _step_func(self, evol_func, t, dt, y): k1 = evol_func(y, t) @@ -575,7 +640,7 @@ class _RK4FixedGridIntegrator(_FixedGridIntegrator): return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6) -def odeint_fixed(func, y0, t, method='rk4', name=None): +def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None): """ODE integration on a fixed grid (with no step size control). Useful in certain scenarios to avoid the overhead of adaptive step size @@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): `y`. The initial time point should be the first element of this sequence, and each time must be larger than the previous time. May have any floating point dtype. + dt: 0-D or 1-D Tensor providing a time step suggestion to be used on the + integration intervals in `t`. A 1-D Tensor should provide values + for all intervals and must have 1 less element than `t`. + If given a 0-D Tensor, the value is interpreted as the same time step + suggestion for all intervals. If passed None, then the time step is set to + t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by + ensuring an integer number of steps per interval, potentially reducing the + time step. method: One of 'midpoint' or 'rk4'. name: Optional name for the resulting operation. 
@@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): Raises: ValueError: Upon caller errors. """ - with ops.name_scope(name, 'odeint_fixed', [y0, t]): + with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]): t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') y0 = ops.convert_to_tensor(y0, name='y0') - _check_input_types(t, y0) + + intervals = t[1:] - t[:-1] + if dt is None: + dt = intervals + dt = ops.convert_to_tensor(dt, preferred_dtype=dtypes.float64, name='dt') + + steps_on_intervals = math_ops.ceil(intervals / dt) + dt = intervals / steps_on_intervals + steps_on_intervals = math_ops.cast(steps_on_intervals, dtype=dtypes.int32) + + _check_input_types(y0, t, dt) + _check_input_sizes(t, dt) with _assert_increasing(t): with ops.name_scope(method): if method == 'midpoint': - return _MidpointFixedGridIntegrator().integrate(func, y0, t) + return _MidpointFixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) elif method == 'rk4': - return _RK4FixedGridIntegrator().integrate(func, y0, t) + return _RK4FixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) else: raise ValueError('method not supported: {!s}'.format(method)) diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py index 3ec01212d25ca8dc6e13f340177a5e85138868d5..c7b4e2faa84e1a87cb1904b22eb0008ab1ee4be6 100644 --- a/tensorflow/contrib/integrate/python/ops/odes_test.py +++ b/tensorflow/contrib/integrate/python/ops/odes_test.py @@ -242,40 +242,56 @@ class InterpolationTest(test.TestCase): class OdeIntFixedTest(test.TestCase): - def _test_integrate_sine(self, method): + def _test_integrate_sine(self, method, t, dt=None): def evol_func(y, t): del t return array_ops.stack([y[1], -y[0]]) y0 = [0., 1.] 
- time_grid = np.linspace(0., 10., 200) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.sin(t), rtol=1e-2, atol=1e-2) - def _test_integrate_gaussian(self, method): + def _test_integrate_gaussian(self, method, t, dt=None): def evol_func(y, t): return -math_ops.cast(t, dtype=y.dtype) * y[0] y0 = [1.] - time_grid = np.linspace(0., 2., 100) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.exp(-t**2 / 2), rtol=1e-2, atol=1e-2) + + def _test_integrate_sine_all(self, method): + uniform_time_grid = np.linspace(0., 10., 200) + non_uniform_time_grid = np.asarray([0.0, 0.4, 4.7, 5.2, 7.0]) + uniform_dt = 0.02 + non_uniform_dt = np.asarray([0.01, 0.001, 0.05, 0.03]) + self._test_integrate_sine(method, uniform_time_grid) + self._test_integrate_sine(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_sine(method, non_uniform_time_grid, non_uniform_dt) + + def _test_integrate_gaussian_all(self, method): + uniform_time_grid = np.linspace(0., 2., 100) + non_uniform_time_grid = np.asarray([0.0, 0.1, 0.7, 1.2, 2.0]) + uniform_dt = 0.01 + non_uniform_dt = np.asarray([0.01, 0.001, 0.1, 0.03]) + self._test_integrate_gaussian(method, uniform_time_grid) + self._test_integrate_gaussian(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_gaussian(method, non_uniform_time_grid, non_uniform_dt) def _test_everything(self, method): - self._test_integrate_sine(method) - self._test_integrate_gaussian(method) + 
self._test_integrate_sine_all(method) + self._test_integrate_gaussian_all(method) def test_midpoint(self): self._test_everything('midpoint') @@ -283,6 +299,21 @@ class OdeIntFixedTest(test.TestCase): def test_rk4(self): self._test_everything('rk4') + def test_dt_size_exceptions(self): + times = np.linspace(0., 2., 100) + dt = np.ones(99) * 0.01 + dt_wrong_length = np.asarray([0.01, 0.001, 0.1, 0.03]) + dt_wrong_dim = np.expand_dims(np.linspace(0., 2., 99), axis=0) + times_wrong_dim = np.expand_dims(np.linspace(0., 2., 100), axis=0) + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_length) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_dim) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times_wrong_dim, dt) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index 3ba1026383ef146adb32197ae41b5c251155bf46..2ede5daee74223e812cc29e9708b1989b698fb4e 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -652,7 +652,8 @@ def map_fn(fn, labeled_tensor, name=None): tensor_lt = core.LabeledTensor(tensor, original_axes) return fn(tensor_lt).tensor - map_op = functional_ops.map_fn(tf_fn, labeled_tensor.tensor) + map_op = functional_ops.map_fn( + tf_fn, labeled_tensor.tensor, dtype=first_map_lt.dtype) map_lt = core.LabeledTensor(map_op, final_axes) return core.identity(map_lt, name=scope) diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index 06060b99e7e58787994f20f037ffa451abbc7459..a85cff4f7098e9a5eedca1b0c8c0cb42e172d90a 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -683,11 
+683,12 @@ def parse_feature_columns_from_sequence_examples( the serialized proto. Returns: - A tuple consisting of: - context_features: a dict mapping `FeatureColumns` from - `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s. - sequence_features: a dict mapping `FeatureColumns` from - `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s. + A tuple consisting of (context_features, sequence_features) + + * context_features: a dict mapping `FeatureColumns` from + `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s. + * sequence_features: a dict mapping `FeatureColumns` from + `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s. """ # Sequence example parsing requires a single (scalar) example. try: diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 541da9061732ad271f6d5456446a9c30b81e58dd..f8a3709ee57a32734afa7ac8133271c75d152b2c 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -505,7 +505,7 @@ class Experiment(object): eval_result = None last_warning_time = 0 while (not predicate_fn or predicate_fn( - eval_result, checkpoint_path=previous_path if eval_result else None)): + eval_result, checkpoint_path=previous_path)): # Exit if we have already reached number of steps to train. 
if self._has_training_stopped(eval_result): logging.info("Exiting continuous eval, global_step=%s >= " diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index d10927a0cdd5c67c8d2a8e569153235ee175ec4d..fb16c94c29660e2777942ea9cf30da51dbf90571 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -500,7 +500,7 @@ class ExperimentTest(test.TestCase): noop_hook = _NoopHook() def _predicate_fn(eval_result, checkpoint_path): - self.assertEqual(not eval_result, + self.assertEqual(eval_result is None, checkpoint_path is None) return est.eval_count < 3 # pylint: disable=cell-var-from-loop diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index cc8a8035d1dadeec98886ba1dae4cdf403f26de4..2b6997146e1e5a3873ed0f94a9221b34bed7621d 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -70,6 +70,12 @@ LIB_PATH := $(LIBDIR)$(LIB_NAME) # A small example program that shows how to link against the library. MINIMAL_PATH := $(BINDIR)minimal +# Benchmark static library and binary +BENCHMARK_LIB_NAME := benchmark-lib.a +BENCHMARK_BINARY_NAME := benchmark_model +BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME) +BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME) + MINIMAL_SRCS := \ tensorflow/contrib/lite/examples/minimal/minimal.cc MINIMAL_OBJS := $(addprefix $(OBJDIR), \ @@ -78,12 +84,19 @@ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS)))) # What sources we want to compile, must be kept in sync with the main Bazel # build files. 
+PROFILER_SRCS := \ + tensorflow/contrib/lite/profiling/time.cc +PROFILE_SUMMARIZER_SRCS := \ + tensorflow/contrib/lite/profiling/profile_summarizer.cc \ + tensorflow/core/util/stats_calculator.cc + CORE_CC_ALL_SRCS := \ $(wildcard tensorflow/contrib/lite/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \ +$(PROFILER_SRCS) \ $(wildcard tensorflow/contrib/lite/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ @@ -107,18 +120,31 @@ TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) LIB_OBJS := $(TF_LITE_CC_OBJS) + +# Benchmark sources +BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark +BENCHMARK_ALL_SRCS := $(TFLITE_CC_SRCS) \ + $(wildcard $(BENCHMARK_SRCS_DIR)/*.cc) \ + $(PROFILE_SUMMARIZER_SRCS) + +BENCHMARK_SRCS := $(filter-out \ + $(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \ + $(BENCHMARK_ALL_SRCS)) + +BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) + # For normal manually-created TensorFlow C++ source files. $(OBJDIR)%.o: %.cc @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ - # For normal manually-created TensorFlow C++ source files. $(OBJDIR)%.o: %.c @mkdir -p $(dir $@) $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ # The target that's compiled if there's no command-line arguments. -all: $(LIB_PATH) $(MINIMAL_PATH) +all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY) # Gathers together all the objects we've compiled into a single '.a' archive. 
$(LIB_PATH): $(LIB_OBJS) @@ -131,6 +157,21 @@ $(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH) -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \ $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) + +$(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS) + @mkdir -p $(dir $@) + $(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS) + +benchmark_lib: $(BENCHMARK_LIB) +$(info $(BENCHMARK_BINARY)) +$(BENCHMARK_BINARY) : $(BENCHMARK_LIB) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $(BENCHMARK_BINARY) \ + $(LIBFLAGS) $(BENCHMARK_LIB) $(LDFLAGS) $(LIBS) + +benchmark: $(BENCHMARK_BINARY) + # Gets rid of all generated files. clean: rm -rf $(MAKEFILE_DIR)/gen diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 4f836d367747e06de682b5764206d33f6e2fb983..22be64d6ff649b4bff45a5e5680984d688a8cf38 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -31,7 +31,7 @@ struct AllocationInfo { // The tensor index to be allocated or deallocated. int tensor; // Whether to allocate or deallocate - enum { ALLOC, DEALLOC } type; + enum Type { ALLOC, DEALLOC } type; }; ArenaPlanner::ArenaPlanner(TfLiteContext* context, @@ -67,6 +67,33 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // Keeps track of references to each tensor. std::vector refcounts(graph_info_->num_tensors(), 0); + // `allocated` and `deallocated` are technically list of boolean values. + // We're saving the compiled binary size by using `vector`. 
+ std::vector<int> allocated(graph_info_->num_tensors(), false); + std::vector<int> deallocated(graph_info_->num_tensors(), false); + + auto allocate = [this, &allocated, &deallocated](int node, + int tensor) -> TfLiteStatus { + if (allocated[tensor]) { + return kTfLiteOk; + } + TF_LITE_ENSURE(context_, !deallocated[tensor]); + alloc_queue_.push_back({node, tensor, AllocationInfo::ALLOC}); + allocated[tensor] = true; + return kTfLiteOk; + }; + + auto deallocate = [this, &allocated, &deallocated]( + int node, int tensor) -> TfLiteStatus { + if (!allocated[tensor]) { + // Do not enqueue a DEALLOC if the tensor was never allocated. + // This happens with constant tensors. + return kTfLiteOk; + } + TF_LITE_ENSURE(context_, !deallocated[tensor]); + alloc_queue_.push_back({node, tensor, AllocationInfo::DEALLOC}); + return kTfLiteOk; + }; // There will be an entry in alloc_queue_ for the allocation of each tensor // and another for their deallocation. @@ -79,6 +106,28 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { refcounts[tensor_index]++; } + // Variable tensors are also never overwritten and need to be alive all + // the time. + for (int tensor_index : graph_info_->variables()) { + refcounts[tensor_index]++; + } + + // Queue all graph inputs for allocation. + for (int tensor_index : graph_info_->inputs()) { + if (tensor_index != kOptionalTensor) { + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); + } + } + + // Queue all graph variable tensors for allocation. + for (int tensor_index : graph_info_->variables()) { + if (tensor_index != kOptionalTensor) { + // The reference count of variable tensors was increased above, so they + // will never be deallocated. + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); + } + } + // Count references to node input tensors. for (int i = 0; i < graph_info_->num_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); @@ -94,10 +143,9 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // Queue all graph inputs for allocation. 
for (int tensor_index : graph_info_->inputs()) { if (tensor_index != kOptionalTensor) { - alloc_queue_.push_back({0, tensor_index, AllocationInfo::ALLOC}); + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); } } - // Go through the graph in execution order. for (int i = 0; i < graph_info_->num_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); @@ -106,7 +154,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { TfLiteIntArray* node_outputs = node.outputs; for (int j = 0; j < node_outputs->size; ++j) { int tensor_index = node_outputs->data[j]; - alloc_queue_.push_back({i, tensor_index, AllocationInfo::ALLOC}); + TF_LITE_ENSURE_STATUS(allocate(i, tensor_index)); } // Then update the ref-counts of the node's inputs, and if necessary queue @@ -117,7 +165,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { if (tensor_index != kOptionalTensor) { refcounts[tensor_index]--; if (refcounts[tensor_index] == 0) { - alloc_queue_.push_back({i, tensor_index, AllocationInfo::DEALLOC}); + TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index)); } } } diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc index a8a8755e2c9e81474f2ff9cd2b85c0eb3d5c3441..f0fd35216f645df59b03340e00daca9322721b1b 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -100,12 +100,18 @@ class TestGraph { std::vector* tensors() { return &tensors_; } const std::vector& inputs() { return inputs_; } const std::vector& outputs() { return outputs_; } + const std::vector& variables() { return variables_; } + + void SetVariables(const std::vector& variables) { + variables_ = variables; + } private: std::vector nodes_; std::vector tensors_; std::vector inputs_; std::vector outputs_; + std::vector variables_; }; // The GraphInfo for a TestGraph. 
@@ -123,6 +129,9 @@ class TestGraphInfo : public GraphInfo { } const std::vector& inputs() const override { return graph_->inputs(); } const std::vector& outputs() const override { return graph_->outputs(); } + const std::vector& variables() const override { + return graph_->variables(); + } private: TestGraph* graph_; @@ -209,11 +218,8 @@ TEST_F(ArenaPlannerTest, ZeroSizedTensors) { TestGraph graph({1}, {{{1}, {2}, {}}}, {2}); (*graph.tensors())[1].bytes = 0; SetGraph(&graph); - // TODO(ahentz): this is currently broken because the arena finds two - // allocations with the same offset and returns an error. - ASSERT_FALSE(planner_->ExecuteAllocations(0, 10) == kTfLiteOk); - // EXPECT_EQ(GetOffset(1), 0); - // EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + ASSERT_EQ(planner_->ExecuteAllocations(0, 10), kTfLiteOk); + EXPECT_EQ((*graph_->tensors())[1].data.raw, nullptr); } TEST_F(ArenaPlannerTest, SimpleGraph) { @@ -309,13 +315,15 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentTensor) { { /* in, out, tmp */ {{0, 1}, {2}, {}}, // First op - {{2, 0}, {4}, {5}}, // Second op, with temporary + {{2, 0}, {4}, {5}}, // Second op, with persistent {{4, -1}, {3}, {}} // Third op, with optional }, {3}); // Make #1 persistent so it goes into its own arena. (*graph.tensors())[1].allocation_type = kTfLiteArenaRwPersistent; + // The only use case for kTfLiteArenaRwPersistent is variable tensor now. 
+ graph.SetVariables({1}); SetGraph(&graph); Execute(0, 10); diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 13d9a463fb9516cd96fa638ebad84ddccf1b59f2..612813caee880f3f7291ee9850f7d8f842d598a6 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -201,7 +201,7 @@ def generated_test_models(): "concat", "constant", "control_dep", - # "conv", + "conv", "depthwiseconv", "div", "equal", @@ -220,6 +220,7 @@ def generated_test_models(): "less_equal", "local_response_norm", "log_softmax", + "log", "lstm", "max_pool", "maximum", diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 7b10b69f438536709f47ac8bd5cb2f8e27d0a1aa..aef9a92883f18dabfc36058507d739856c3c2af7 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by -// `schema_builtin_ops_header_generator.py`. +// `schema/builtin_ops_header/generator.cc`. 
#ifdef __cplusplus extern "C" { @@ -98,6 +98,7 @@ typedef enum { kTfLiteBuiltinExpandDims = 70, kTfLiteBuiltinEqual = 71, kTfLiteBuiltinNotEqual = 72, + kTfLiteBuiltinLog = 73, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c index 5c6f5e72a47180cd98be46f60cfa8eaf28197806..7f2aa316f4a9a265b14a216a6ffa53c7f0757426 100644 --- a/tensorflow/contrib/lite/context.c +++ b/tensorflow/contrib/lite/context.c @@ -76,7 +76,7 @@ void TfLiteTensorFree(TfLiteTensor* t) { void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, TfLiteQuantizationParams quantization, char* buffer, size_t size, TfLiteAllocationType allocation_type, - const void* allocation, TfLiteTensor* tensor) { + const void* allocation, bool is_variable, TfLiteTensor* tensor) { TfLiteTensorFree(tensor); tensor->type = type; tensor->name = name; @@ -86,6 +86,7 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, tensor->bytes = size; tensor->allocation_type = allocation_type; tensor->allocation = allocation; + tensor->is_variable = is_variable; } void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 4eb66cc225eb04923be9aaa445a335ad822c8a6f..15a37de9dc665ff147b7094a61a5afab701932ce 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -138,6 +138,7 @@ typedef enum { kTfLiteInt64 = 4, kTfLiteString = 5, kTfLiteBool = 6, + kTfLiteInt16 = 7, } TfLiteType; // Parameters for asymmetric quantization. Quantized values can be converted @@ -148,7 +149,7 @@ typedef struct { int32_t zero_point; } TfLiteQuantizationParams; -// A union of points that points to memory for a given tensor. +// A union of pointers that points to memory for a given tensor. 
typedef union { int* i32; int64_t* i64; @@ -157,6 +158,7 @@ typedef union { const char* raw_const; uint8_t* uint8; bool* b; + int16_t* i16; } TfLitePtrUnion; // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped @@ -223,6 +225,9 @@ typedef struct { // delegate buffer. // WARNING: This is an // experimental interface that is subject to change. bool data_is_stale; + + // True if the tensor is a variable. + bool is_variable; } TfLiteTensor; // Free data memory of tensor `t`; @@ -235,7 +240,8 @@ void TfLiteTensorFree(TfLiteTensor* t); void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, TfLiteQuantizationParams quantization, char* buffer, size_t size, TfLiteAllocationType allocation_type, - const void* allocation, TfLiteTensor* tensor); + const void* allocation, bool is_variable, + TfLiteTensor* tensor); // Resize the allocated data of a (dynamic) tensor. void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD index 9322e186a280e932a2441ab16ac8579d9ab67ee2..c61445114ecc6dfbe4f2b6ab666b28a8aa746be3 100644 --- a/tensorflow/contrib/lite/examples/label_image/BUILD +++ b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -53,19 +53,18 @@ cc_library( ], ) -# TODO(ahentz): Test disabled as it has a memory leek from read_bmp -# cc_test( -# name = "label_image_test", -# srcs = [ -# "get_top_n.h", -# "get_top_n_impl.h", -# "label_image_test.cc", -# ], -# data = [ -# "testdata/grace_hopper.bmp", -# ], -# deps = [ -# ":bitmap_helpers", -# "//testing/base/public:gunit", -# ], -# ) +cc_test( + name = "label_image_test", + srcs = [ + "get_top_n.h", + "get_top_n_impl.h", + "label_image_test.cc", + ], + data = [ + "testdata/grace_hopper.bmp", + ], + deps = [ + ":bitmap_helpers", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc 
b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc index 0b38cd38c83927c65d251b9356301b6bef7521f2..2735d1f5ea4e2a104f71a3a6f874d9acb2f48142 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc @@ -28,8 +28,9 @@ limitations under the License. namespace tflite { namespace label_image { -uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, - int width, int height, int channels, bool top_down) { +std::vector decode_bmp(const uint8_t* input, int row_size, int width, + int height, int channels, bool top_down) { + std::vector output(height * width * channels); for (int i = 0; i < height; i++) { int src_pos; int dst_pos; @@ -66,12 +67,11 @@ uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, } } } - return output; } -uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, - int* channels, Settings* s) { +std::vector read_bmp(const std::string& input_bmp_name, int* width, + int* height, int* channels, Settings* s) { int begin, end; std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary); @@ -87,14 +87,15 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, if (s->verbose) LOG(INFO) << "len: " << len << "\n"; - const uint8_t* img_bytes = new uint8_t[len]; + std::vector img_bytes(len); file.seekg(0, std::ios::beg); - file.read((char*)img_bytes, len); + file.read(reinterpret_cast(img_bytes.data()), len); const int32_t header_size = - *(reinterpret_cast(img_bytes + 10)); - *width = *(reinterpret_cast(img_bytes + 18)); - *height = *(reinterpret_cast(img_bytes + 22)); - const int32_t bpp = *(reinterpret_cast(img_bytes + 28)); + *(reinterpret_cast(img_bytes.data() + 10)); + *width = *(reinterpret_cast(img_bytes.data() + 18)); + *height = *(reinterpret_cast(img_bytes.data() + 22)); + const int32_t bpp = + *(reinterpret_cast(img_bytes.data() + 28)); *channels = bpp 
/ 8; if (s->verbose) @@ -110,10 +111,9 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, bool top_down = (*height < 0); // Decode image, allocating tensor once the image size is known - uint8_t* output = new uint8_t[abs(*height) * *width * *channels]; const uint8_t* bmp_pixels = &img_bytes[header_size]; - return decode_bmp(bmp_pixels, row_size, output, *width, abs(*height), - *channels, top_down); + return decode_bmp(bmp_pixels, row_size, *width, abs(*height), *channels, + top_down); } } // namespace label_image diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h index 97343dde6b31694e5b2de20b35a7083fb8fe4a0e..5fc75b1f7274c14d49e4a26d6ce4902c037afa6b 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h @@ -22,8 +22,8 @@ limitations under the License. namespace tflite { namespace label_image { -uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, - int* channels, Settings* s); +std::vector read_bmp(const std::string& input_bmp_name, int* width, + int* height, int* channels, Settings* s); template void resize(T* out, uint8_t* in, int image_height, int image_width, diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index 966fcd2a31fd4d4ff2c3e91633550a8effa81ee8..86d7d1cc4a625243791d5e7d5b746526a58efb6d 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -138,8 +138,8 @@ void RunInference(Settings* s) { int image_width = 224; int image_height = 224; int image_channels = 3; - uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height, - &image_channels, s); + std::vector in = read_bmp(s->input_bmp_name, &image_width, + &image_height, &image_channels, s); int 
input = interpreter->inputs()[0]; if (s->verbose) LOG(INFO) << "input: " << input << "\n"; @@ -168,12 +168,12 @@ void RunInference(Settings* s) { switch (interpreter->tensor(input)->type) { case kTfLiteFloat32: s->input_floating = true; - resize(interpreter->typed_tensor(input), in, image_height, - image_width, image_channels, wanted_height, wanted_width, - wanted_channels, s); + resize(interpreter->typed_tensor(input), in.data(), + image_height, image_width, image_channels, wanted_height, + wanted_width, wanted_channels, s); break; case kTfLiteUInt8: - resize(interpreter->typed_tensor(input), in, + resize(interpreter->typed_tensor(input), in.data(), image_height, image_width, image_channels, wanted_height, wanted_width, wanted_channels, s); break; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc index ce35483f76e8f40ced79e1ee30774c62d0eba94e..de7de21f7741d3d46cb96e793e8bc4bfb21384fe 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc @@ -27,20 +27,20 @@ namespace label_image { TEST(LabelImageTest, GraceHopper) { std::string lena_file = - "tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp"; + "tensorflow/contrib/lite/examples/label_image/testdata/" + "grace_hopper.bmp"; int height, width, channels; Settings s; - uint8_t *data; - - data = read_bmp(lena_file, &width, &height, &channels, &s); + std::vector input = + read_bmp(lena_file, &width, &height, &channels, &s); ASSERT_EQ(height, 606); ASSERT_EQ(width, 517); ASSERT_EQ(channels, 3); - uint8_t *out = new uint8_t[606 * 517 * 3]; - downsize(out, data, 606, 517, 3, 214, 214, 3, &s); - ASSERT_EQ(out[0], 0x15); - ASSERT_EQ(out[214 * 214 * 3 - 1], 0x12); + std::vector output(606 * 517 * 3); + resize(output.data(), input.data(), 606, 517, 3, 214, 214, 3, &s); + ASSERT_EQ(output[0], 0x15); + ASSERT_EQ(output[214 
* 214 * 3 - 1], 0x11); } TEST(LabelImageTest, GetTopN) { diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md new file mode 100644 index 0000000000000000000000000000000000000000..bd2f797e6c5b05f52bec9fc34f1b8011aca70330 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md @@ -0,0 +1,206 @@ +# TensorFlow Lite Ops Versioning + +This document describes TensorFlow Lite's op versioning schema. Op +versioning enables developers to add new functionalities and parameters into +existing ops. In addition, it guarantees the following: + +* Backward compatibility: New TensorFlow Lite implementation should + handle an old model file. +* Forward compatibility: Old TensorFlow Lite implementation should + handle a new model file produced by new version of TOCO, as long as no new + features are used. +* Forward in-compatibility detection: If an old TensorFlow Lite implementation + reads a new model that contains a new version of an op which isn't + supported, it should report the error. + +## Example: Adding Dilation into Convolution + +The remainder of this document explains op versioning in TFLite by showing how +to add dilation parameters to the convolution operation. + +Knowledge of dilation is not required to understand this document. Note that: + +* 2 new integer parameters will be added: `dilation_width_factor` and + `dilation_height_factor`. +* Old convolution kernels that don't support dilation are equivalent to + setting the dilation factors to 1. + +### Change FlatBuffer Schema + +To add new parameters into an op, change the options table in +`lite/schema/schema.fbs`. + +For example, the options table of convolution looks like this: + +``` +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} +``` + +When adding new parameters: + +* Add comments indicating which parameters are supported by which version. 
+* When the new implementation gets the default values for newly added + parameters, it should work exactly the same as the old implementation. + +The table will be like this after the new parameters are added: + +``` +table Conv2DOptions { + // Parameters supported by version 1: + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + + // Parameters supported by version 2: + dilation_width_factor:int = 1; + dilation_height_factor:int = 1; +} +``` + +### Change C Structures and Kernel Implementation + +In TensorFlow Lite, the kernel implementation is decoupled from +FlatBuffer definition. The kernels read the parameter from C structures defined +in `lite/builtin_op_data.h`. + +The original convolution parameter is as follows: + +``` +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; +} TfLiteConvParams; +``` + +As with the FlatBuffer schema, add comments indicating which parameters are +supported starting from which version. The result is seen below: + +``` +typedef struct { + // Parameters supported by version 1: + TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; + + // Parameters supported by version 2: + int dilation_width_factor; + int dilation_height_factor; +} TfLiteConvParams; +``` + +Please also change the kernel implementation to read the newly added parameters +from the C structures. The details are omitted here. + +### Change the FlatBuffer Reading Code + +The logic to read FlatBuffer and produce C structure is in `lite/model.cc`. 
+ +Update the file to handle the new parameters, as shown below: + +``` +case BuiltinOperator_CONV_2D: { + TfLiteConvParams* params = MallocPOD<TfLiteConvParams>(); + if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + params->dilation_width_factor = conv_params->dilation_width_factor(); + params->dilation_height_factor = conv_params->dilation_height_factor(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; +} +``` + +It's not required to check the op version here. When the new implementation +reads an old model file where dilation factors are missing, it will use 1 as +the default value, and the new kernel will work consistently with the old +kernel. + +### Change Kernel Registration + +The MutableOpResolver (defined in `lite/op_resolver.h`) provides a few functions +to register op kernels. The minimum and maximum version are 1 by default: + +``` +void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); +void AddCustom(const char* name, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); +``` + +The built-in ops are registered in `lite/kernels/register.cc`. In this example, +we implemented a new op kernel which can handle `Conv2D` version 1 and 2, so we +need to change this line: + +``` +AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D()); +``` + +to: + +``` +AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), 1, 2); +``` + +### Change TOCO TFLite exporter + +The last step is to make TOCO populate the minimum version that's required to +execute the op. In this example, it means: + +* Populate version=1 when dilation factors are all 1. +* Populate version=2 otherwise. 
+ +To do this, you need to override the `GetVersion` function for the operator class in +`lite/toco/tflite/operator.cc`. + +For ops with only one version, the `GetVersion` function is defined as: + +``` +int GetVersion(const Operator& op) const override { return 1; } +``` + +When supporting multiple versions, check the parameters and determine the +version for the op, as shown in the following example: + +``` +int GetVersion(const Operator& op) const override { + const auto& conv_op = static_cast<const ConvOperator&>(op); + if (conv_op.dilation_width_factor != 1 || + conv_op.dilation_height_factor != 1) { + return 2; + } + return 1; +} +``` + +### Delegation Implementation + +TensorFlow Lite provides a delegation API which enables delegating ops to +hardware backends. In Delegate's `Prepare` function, check if the version +is supported for every node in Delegation code. + +``` +const int kMinVersion = 1; +TfLiteNode* node; +TfLiteRegistration* registration; +context->GetNodeAndRegistration(context, node_index, &node, &registration); + +if (registration->version > kMinVersion) { + // Reject the node if the version isn't supported. +} +``` + +This is required even if the delegation only supports version 1 ops, so the +delegation can detect incompatibility when getting a higher version op. 
+ diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index a991a239d92d847af2ae17acab472dd823ad236f..965273f0f04d33b52903c0551fff3533c31d3bd8 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -429,6 +429,17 @@ Outputs { } ``` +**LOG** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: a tensor equivalent to log(input) +} +``` + **LOG_SOFTMAX** ``` diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h index 313af5fb7574b42bcdd53b4baad06e4ccfb34053..77268d7aebe9ebfb33b9f35b319d34e6de8324ee 100644 --- a/tensorflow/contrib/lite/graph_info.h +++ b/tensorflow/contrib/lite/graph_info.h @@ -46,6 +46,9 @@ class GraphInfo { // Returns the indices of the output tensors. virtual const std::vector& outputs() const = 0; + + // Returns the indices of the variable tensors. + virtual const std::vector& variables() const = 0; }; // Represents a subgraph of a TensorFlow Lite graph. diff --git a/tensorflow/contrib/lite/graph_info_test.cc b/tensorflow/contrib/lite/graph_info_test.cc index ea38b43993fef71c6820c7a978351d92d5420287..89a8f36b416b5dec54c1e374cdcdae3ab9ab0cde 100644 --- a/tensorflow/contrib/lite/graph_info_test.cc +++ b/tensorflow/contrib/lite/graph_info_test.cc @@ -45,6 +45,7 @@ class SimpleTestGraph : public GraphInfo { TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; } const std::vector& inputs() const override { return inputs_; } const std::vector& outputs() const override { return outputs_; } + const std::vector& variables() const override { return variables_; } void AddNode(const std::vector& inputs, const std::vector& outputs) { @@ -67,6 +68,7 @@ class SimpleTestGraph : public GraphInfo { std::vector tensors_; std::vector inputs_; std::vector outputs_; + std::vector variables_; }; // Partition a graph to generate a list of subgraphs. 
This wraps the API call diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index ebb0aedc2001a86b7fcff67ef8703b5e4a845818..3287f9c4fdeeb8949e6fa15f4ec8c0aca2dd8a08 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -82,6 +82,9 @@ class InterpreterInfo : public GraphInfo { const std::vector& outputs() const override { return interpreter_->outputs(); } + const std::vector& variables() const override { + return interpreter_->variables(); + } public: Interpreter* interpreter_; @@ -302,6 +305,13 @@ TfLiteStatus Interpreter::SetOutputs(std::vector outputs) { return kTfLiteOk; } +TfLiteStatus Interpreter::SetVariables(std::vector variables) { + TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("variables", variables.data(), + variables.size())); + variables_ = std::move(variables); + return kTfLiteOk; +} + TfLiteStatus Interpreter::CheckTensorIndices(const char* label, const int* indices, int length) { // Making sure kOptionalTensor is not re-defined to something other than -1. 
@@ -334,6 +344,9 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, case kTfLiteFloat32: *bytes = sizeof(float) * count; break; + case kTfLiteInt16: + *bytes = sizeof(int16_t) * count; + break; case kTfLiteInt32: *bytes = sizeof(int32_t) * count; break; @@ -347,9 +360,9 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, *bytes = sizeof(bool) * count; break; default: - ReportError( - &context_, - "Only float32, int32, int64, uint8, bool supported currently."); + ReportError(&context_, + "Only float32, int16, int32, int64, uint8, bool supported " + "currently."); return kTfLiteError; } return kTfLiteOk; @@ -367,6 +380,7 @@ TfLiteStatus Interpreter::AllocateTensors() { } TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); + if (state_ == kStateUninvokable) { state_ = kStateInvokable; } @@ -375,6 +389,25 @@ TfLiteStatus Interpreter::AllocateTensors() { return kTfLiteOk; } +// TODO(ycling): Consider to provide other functions to initialize variable +// tensors to non-zero values. +TfLiteStatus Interpreter::ResetVariableTensorsToZero() { + for (auto& tensor : tensors_) { + if (!tensor.is_variable) { + continue; + } + + // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must be + // allocated after the initial `PrepareOpsAndTensors()` is called. 
+ TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type, + kTfLiteArenaRwPersistent); + TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr); + + memset(tensor.data.raw, 0, tensor.bytes); + } + return kTfLiteOk; +} + TfLiteStatus Interpreter::AddNodeWithParameters( const std::vector& inputs, const std::vector& outputs, const char* init_data, size_t init_data_size, void* builtin_data, @@ -687,7 +720,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( state_ = kStateUninvokable; TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, const_cast(buffer), bytes, - kTfLiteMmapRo, allocation, &tensor); + kTfLiteMmapRo, allocation, false, &tensor); } return kTfLiteOk; } @@ -698,7 +731,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( // to Interpreter. TfLiteStatus Interpreter::SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization) { + const int* dims, TfLiteQuantizationParams quantization, bool is_variable) { if (state_ == kStateInvokableAndImmutable) { ReportError( &context_, @@ -716,11 +749,23 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite( TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims, rank, &required_bytes)); } + + TfLiteAllocationType allocation_type = kTfLiteArenaRw; + if (type == kTfLiteString) { + if (is_variable) { + // We don't have a real use case for string variable tensor. + ReportError(&context_, "String variable tensor isn't supported."); + return kTfLiteError; + } + allocation_type = kTfLiteDynamic; + } else if (is_variable) { + allocation_type = kTfLiteArenaRwPersistent; + } + TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, - /*buffer=*/nullptr, required_bytes, - type == kTfLiteString ? 
kTfLiteDynamic : kTfLiteArenaRw, - nullptr, &context_.tensors[tensor_index]); + /*buffer=*/nullptr, required_bytes, allocation_type, + nullptr, is_variable, &context_.tensors[tensor_index]); return kTfLiteOk; } @@ -736,7 +781,8 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size) { // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. if (tensor->allocation_type == kTfLiteArenaRw || - tensor->allocation_type == kTfLiteDynamic) { + tensor->allocation_type == kTfLiteDynamic || + tensor->allocation_type == kTfLiteArenaRwPersistent) { if (tensor->type != kTfLiteString) { size_t bytesRequired; TfLiteStatus status = BytesRequired(tensor->type, new_size->data, diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 7315d8360680ca0d3c405dc80b593762275815ee..37961cd1dc97607510edc9e6f0141c8bfc431c0d 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -118,6 +118,11 @@ class Interpreter { // interpreter. TfLiteStatus SetOutputs(std::vector outputs); + // Provide a list of tensor indexes that are variable tensors. + // Each index is bound check and this modifies the consistent_ flag of the + // interpreter. + TfLiteStatus SetVariables(std::vector variables); + // Adds a node with the given parameters and returns the index of the new // node in `node_index` (optionally). Interpreter will take ownership of // `builtin_data` and destroy it with `free`. Ownership of 'init_data' @@ -160,13 +165,15 @@ class Interpreter { // to Interpreter. 
inline TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, - const std::vector& dims, TfLiteQuantizationParams quantization) { + const std::vector& dims, TfLiteQuantizationParams quantization, + bool is_variable = false) { return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(), - dims.data(), quantization); + dims.data(), quantization, is_variable); } TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization); + const int* dims, TfLiteQuantizationParams quantization, + bool is_variable = false); // Functions to access tensor data @@ -182,6 +189,9 @@ class Interpreter { // Read only access to list of outputs. const std::vector& outputs() const { return outputs_; } + // Read only access to list of variable tensors. + const std::vector& variables() const { return variables_; } + // Return the name of a given output. The given index must be between 0 and // outputs().size(). const char* GetOutputName(int index) const { @@ -379,6 +389,10 @@ class Interpreter { allow_buffer_handle_output_ = allow_buffer_handle_output; } + // Reset all variable tensors to zero. + // WARNING: This is an experimental API and subject to change. + TfLiteStatus ResetVariableTensorsToZero(); + private: // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. @@ -541,6 +555,9 @@ class Interpreter { // interpreter. std::vector outputs_; + // Array of indices representing the tensors that are variable tensors. + std::vector variables_; + // The error reporter delegate that tflite will forward queries errors to. 
ErrorReporter* error_reporter_; diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 453c1ada1cf6263be14a3b170f209e3a30580cc3..b977cb089c39e3904d1d9f83551fc401e82663d8 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -106,10 +106,9 @@ TEST(BasicInterpreter, CheckAllocate) { TfLiteType type; size_t size; } cases[] = { - {kTfLiteFloat32, sizeof(float)}, - {kTfLiteInt32, sizeof(int32_t)}, - {kTfLiteUInt8, sizeof(uint8_t)}, - {kTfLiteInt64, sizeof(int64_t)}, + {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)}, + {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)}, + {kTfLiteInt16, sizeof(int16_t)}, }; for (auto test : cases) { @@ -134,6 +133,7 @@ TEST(BasicInterpreter, CheckResize) { const int32_t int32s[] = {-3, -4}; const uint8_t uint8s[] = {3, 4}; const int64_t int64s[] = {6, -7}; + const int16_t int16s[] = {8, -9}; struct { TfLiteType type; @@ -144,6 +144,7 @@ TEST(BasicInterpreter, CheckResize) { {kTfLiteInt32, sizeof(int32_t), reinterpret_cast(int32s)}, {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast(uint8s)}, {kTfLiteInt64, sizeof(int64_t), reinterpret_cast(int64s)}, + {kTfLiteInt16, sizeof(int16_t), reinterpret_cast(int16s)}, }; for (auto test : cases) { @@ -179,10 +180,8 @@ TEST(BasicInterpreter, CheckAlignment) { struct { TfLiteType type; } cases[] = { - {kTfLiteFloat32}, - {kTfLiteInt32}, - {kTfLiteUInt8}, - {kTfLiteInt64}, + {kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8}, + {kTfLiteInt64}, {kTfLiteInt16}, }; for (auto test : cases) { @@ -211,7 +210,7 @@ TEST(BasicInterpreter, CheckArenaAllocation) { TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; std::vector sizes{2048, 4096, 1023, 2047, 1021, - 2047, 1023, 2046, 1021, 2048}; + 2047, 1023, 2046, 0, 2048}; for (int i = 0; i < sizes.size(); ++i) { interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]}, quant); @@ -228,6 +227,7 @@ 
TEST(BasicInterpreter, CheckArenaAllocation) { ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw); ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw); + ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr); ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw); ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw); @@ -314,6 +314,18 @@ TEST(BasicInterpreter, ResizingTensors) { EXPECT_EQ(tensor->bytes, 8 * sizeof(float)); ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(t, {}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 1 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {0}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 0); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 0}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 0); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + // TODO(ahentz): We shouldn't have to force reallocation, but // ResizeInputTensor doesn't realloc dynamic tensors. Also note that // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op. diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md index 2e818f728ef208d30b0eeb27ffd7e3fa0c7c1a2d..e3cea19e1683ac2680521bce66d1328e4b2caf1c 100644 --- a/tensorflow/contrib/lite/java/demo/README.md +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -1,5 +1,14 @@ # TF Lite Android App +## Building in Android Studio with TensorFlow Lite AAR from JCenter. +The build.gradle is configured to use TensorFlow Lite's nightly build. + +If you see a build error related to compatibility with TensorFlow Lite's Java API (example: method X is +undefined for type Interpreter), there has likely been a backwards incompatible +change to the API. 
You will need to pull new app code that's compatible with the +nightly build and may need to first wait a few days for our external and internal +code to merge. + ## Building from Source with Bazel 1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel): diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle index b76eaad8bb91224805d16b3d6f7c3274c9feb90c..44ea2dcd908644bcfc637f71573ce722adaf6935 100644 --- a/tensorflow/contrib/lite/java/demo/app/build.gradle +++ b/tensorflow/contrib/lite/java/demo/app/build.gradle @@ -52,7 +52,43 @@ dependencies { compile 'com.android.support:support-annotations:25.3.1' compile 'com.android.support:support-v13:25.2.0' - compile 'org.tensorflow:tensorflow-lite:+' + compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly' testCompile 'junit:junit:4.12' } + +def modelDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" +def localCache = "build/intermediates/mobilenet_v1_224_android_quant_2017_11_08.zip" +def targetFolder = "src/main/assets" + +task downloadModel(type: DownloadUrlTask) { + doFirst { + println "Downloading ${modelDownloadUrl}" + } + sourceUrl = "${modelDownloadUrl}" + target = file("${localCache}") +} + +task unzipModel(type: Copy, dependsOn: 'downloadModel') { + doFirst { + println "Unzipping ${localCache}" + } + from zipTree("${localCache}") + into "${targetFolder}" +} + +// Ensure the model file is downloaded and extracted before every build +preBuild.dependsOn unzipModel + +class DownloadUrlTask extends DefaultTask { + @Input + String sourceUrl + + @OutputFile + File target + + @TaskAction + void download() { + ant.get(src: sourceUrl, dest: target) + } +} \ No newline at end of file diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index 
7ca1e35489cba3b5d2567bc04e532fedf8a527a7..443ce8924a43669fb264e19561c733d7e3436cb0 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -126,16 +126,19 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, int32 input1_multiplier; int input1_shift; - QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, - &input1_shift); + QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, + &input1_multiplier, &input1_shift); + input1_shift *= -1; int32 input2_multiplier; int input2_shift; - QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, - &input2_shift); + QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, + &input2_multiplier, &input2_shift); + input2_shift *= -1; int32 output_multiplier; int output_shift; - QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, - &output_shift); + QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, + &output_multiplier, &output_shift); + output_shift *= -1; int32 output_activation_min, output_activation_max; CalculateActivationRangeUint8(params->activation, output, diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index ee42e5cdc838fac4bf9a3de15b7e95e001588907..14b399ef96eab1d5066a22a7eb95ab061e8ba2bc 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -134,7 +134,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). 
data->need_im2col = (params->stride_width != 1 || params->stride_height != 1 || - filter_width != 1 || filter_height != 1); + params->dilation_width_factor != 1 || + params->dilation_height_factor != 1 || filter_width != 1 || + filter_height != 1); // If we're using the optimized multithreaded EigenTensor implementation of // convolution, it expects the filter weights to be transposed compared to // the normal TF Lite buffer format. Typical TF Lite weights are @@ -255,8 +257,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); TF_LITE_ENSURE(context, real_multiplier < 1.0); - QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, - &data->output_shift); + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index 0bd504695074011efd946f4c4d1f8d4854e82730..98c21ce9d390aaa1f3cb5fdb8f31cbffb1b81d6a 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -23,7 +23,7 @@ namespace ops { namespace builtin { namespace elementwise { -TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input = GetInput(context, node, 0); @@ -35,7 +35,8 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayCopy(input->dims)); } -TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { +inline TfLiteStatus Eval(TfLiteContext* context, 
TfLiteNode* node, + float float_func(float)) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { @@ -44,7 +45,7 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { const float* in = GetTensorData(input); const float* in_end = in + elements; float* out = output->data.f; - for (; in < in_end; in++, out++) *out = std::sin(*in); + for (; in < in_end; in++, out++) *out = float_func(*in); return kTfLiteOk; } default: { @@ -55,14 +56,28 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { } } +TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::sin); +} + +TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::log); +} + } // namespace elementwise TfLiteRegistration* Register_SIN() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::SinPrepare, + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, elementwise::SinEval}; return &r; } +TfLiteRegistration* Register_LOG() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::LogEval}; + return &r; +} + } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc index 412ffb04b90fbc24d232d25d2a86ce639752c3e8..10e88d5a31868eeb5f65c7ade1f1c73827dea24a 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -24,12 +24,13 @@ namespace { using ::testing::ElementsAreArray; -class SinOpModel : public SingleOpModel { +class ElementWiseOpModel : public SingleOpModel { public: - SinOpModel(std::initializer_list input_shape) { + ElementWiseOpModel(BuiltinOperator op, + std::initializer_list input_shape) { input_ = AddInput(TensorType_FLOAT32); output_ = 
AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_SIN, BuiltinOptions_NONE, 0); + SetBuiltinOp(op, BuiltinOptions_NONE, 0); BuildInterpreter({input_shape}); } @@ -42,7 +43,7 @@ class SinOpModel : public SingleOpModel { }; TEST(ElementWise, Sin) { - SinOpModel m({1, 1, 4, 1}); + ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {0, 3.1415926, -3.1415926, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -50,6 +51,15 @@ TEST(ElementWise, Sin) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, Log) { + ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {1, 3.1415926, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({0, 1.14473, 0, 0}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc index 7539c0b30ded921df957217bebdc7b20ea4b40b4..9410bead5e7a68363d034c22fb2c0eff9f060ef1 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -24,7 +24,8 @@ limitations under the License. // Output: // Output.dim[0] == Tensor[0].dim[0], num of lookups // Output.dim[1] == Tensor[1].dim[1], num of items per row -// Each item in output is a raw bytes copy of corresponding item in input. +// Each item in output is a raw bytes copy of the corresponding item in input, +// or a dequantized value in the case of a uint8 input. // When indices are out of bound, the ops will not succeed. 
// @@ -69,11 +70,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, outputSize); } -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* output = GetOutput(context, node, 0); - const TfLiteTensor* lookup = GetInput(context, node, 0); - const TfLiteTensor* value = GetInput(context, node, 1); - +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* lookup, const TfLiteTensor* value, + TfLiteTensor* output) { const int row_size = SizeOfDimension(value, 0); const int row_bytes = value->bytes / row_size; @@ -91,6 +90,52 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* lookup, const TfLiteTensor* value, + TfLiteTensor* output) { + const int row_size = SizeOfDimension(value, 0); + const double scaling_factor = 1.0 / value->params.scale; + + // col_size after we flatten tensor into 2D. + int col_size = 1; + for (int i = 1; i < NumDimensions(value); i++) { + col_size *= SizeOfDimension(value, i); + } + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = lookup->data.i32[i]; + if (idx >= row_size || idx < 0) { + context->ReportError(context, "Embedding Lookup: index out of bounds."); + return kTfLiteError; + } else { + // Dequantize embedding values. + // TODO(alanchiao): refactor scalar multiply into separate function + // for ease of adding a neon equivalent if ever necessary. 
+ for (int j = 0; j < col_size; j++) { + output->data.f[j + i * col_size] = + value->data.uint8[j + idx * col_size] * scaling_factor; + } + } + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* lookup = GetInput(context, node, 0); + const TfLiteTensor* value = GetInput(context, node, 1); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (value->type) { + case kTfLiteFloat32: + return EvalFloat(context, node, lookup, value, output); + case kTfLiteUInt8: + return EvalHybrid(context, node, lookup, value, output); + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } +} + } // namespace embedding_lookup TfLiteRegistration* Register_EMBEDDING_LOOKUP() { diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 9b501878f196216a61568bfa36e6615f4dd07478..04657fd86323ef1c58d069c06097c7665f55cc87 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -7,13 +7,14 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License +for the specific language governing permissions and limitations under the +License. ==============================================================================*/ // Unit test for TFLite Lookup op. 
+#include #include #include @@ -29,12 +30,13 @@ namespace { using ::testing::ElementsAreArray; -class EmbeddingLookupOpModel : public SingleOpModel { +class BaseEmbeddingLookupOpModel : public SingleOpModel { public: - EmbeddingLookupOpModel(std::initializer_list index_shape, - std::initializer_list weight_shape) { + BaseEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape, + TensorType weight_type = TensorType_FLOAT32) { input_ = AddInput(TensorType_INT32); - weight_ = AddInput(TensorType_FLOAT32); + weight_ = AddInput(weight_type); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0); BuildInterpreter({index_shape, weight_shape}); @@ -44,6 +46,18 @@ class EmbeddingLookupOpModel : public SingleOpModel { PopulateTensor(input_, data); } + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weight_; + int output_; +}; + +class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel; + void Set3DWeightMatrix(const std::function& function) { TfLiteTensor* tensor = interpreter_->tensor(weight_); int rows = tensor->dims->data[0]; @@ -57,20 +71,25 @@ class EmbeddingLookupOpModel : public SingleOpModel { } } } +}; - std::vector GetOutput() { return ExtractVector(output_); } +class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + HybridEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape) + : BaseEmbeddingLookupOpModel(index_shape, weight_shape, + TensorType_UINT8) {} - private: - int input_; - int weight_; - int output_; + void SetWeight(std::initializer_list data) { + SymmetricQuantizeAndPopulate(weight_, data); + } }; // TODO(ahentz): write more tests that exercise the details of the op, such as // lookup errors and variable input shapes. 
TEST(EmbeddingLookupOpTest, SimpleTest) { EmbeddingLookupOpModel m({3}, {3, 2, 4}); - m.PopulateTensor(0, {1, 0, 2}); + m.SetInput({1, 0, 2}); m.Set3DWeightMatrix( [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); @@ -84,6 +103,69 @@ TEST(EmbeddingLookupOpTest, SimpleTest) { }))); } +TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 8}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + +TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + +TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 
1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index 989920622dff1fe246efb920e0d18efa5f8e9215..f6fc0f5b6ad12d58c541efc6eae566ab4b8327f4 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -105,7 +105,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int batch_size = input_size / filter->dims->data[1]; const int num_units = filter->dims->data[0]; - TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]); + TF_LITE_ENSURE_EQ(context, input_size, batch_size * filter->dims->data[1]); if (bias) { TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0)); } @@ -118,8 +118,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); TF_LITE_ENSURE(context, real_multiplier < 1.0); - QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, - &data->output_shift); + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 0a5223b23529ef80b251d5144a94c5969c5cc02c..7962fcbc9d6c839ea11d7355e955239194442e03 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -176,6 +176,40 @@ cc_library( }), ) +cc_library( + name = "legacy_optimized_base", + srcs = [], + hdrs = [ + "common.h", + 
"optimized/depthwiseconv_float.h", + "optimized/depthwiseconv_uint8.h", + "optimized/depthwiseconv_uint8_3x3_filter.h", + "optimized/legacy_optimized_ops.h", + "optimized/optimized_ops.h", + ], + copts = tflite_copts(), + deps = [ + ":quantization_util", + ":strided_slice_logic", + ":types", + ":legacy_reference_base", + ":round", + "//third_party/eigen3", + "@gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, + ":freebsd": tflite_deps_intel, + "//conditions:default": [], + }), +) + cc_library( name = "optimized", hdrs = [ @@ -273,6 +307,37 @@ cc_library( }), ) +cc_library( + name = "legacy_reference_base", + srcs = [], + hdrs = [ + "common.h", + "reference/depthwiseconv_float.h", + "reference/depthwiseconv_uint8.h", + "reference/legacy_reference_ops.h", + "reference/reference_ops.h", + ], + deps = [ + ":quantization_util", + ":round", + ":strided_slice_logic", + ":types", + "//third_party/eigen3", + "@gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, + ":freebsd": tflite_deps_intel, + "//conditions:default": [], + }), +) + cc_library( name = "reference", hdrs = ["tensor.h"], @@ -474,8 +539,9 @@ cc_test( ) cc_test( - name = "resize_bilinear_float_test", - srcs = ["resize_bilinear_float_test.cc"], + name = "resize_bilinear_test", + srcs = ["resize_bilinear_test.cc"], + tags = ["tflite_not_portable"], deps = [ ":optimized_base", ":reference_base", diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc 
index 6e621839753b5dacd96f2615b47e878dbe1de683..36c25388e8bde721d7644dc83d5b7c490d37b4d3 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -350,7 +350,7 @@ void LstmStep( for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = - scaling_factors[b] * input_to_cell_weights_scale; + scaling_factors[b] * input_to_output_weights_scale; } tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, @@ -409,7 +409,7 @@ void LstmStep( } // Save quantization and matmul computation for all zero input. - const bool is_cell_state_all_zeros = + bool is_cell_state_all_zeros = tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); // For each batch and cell: update input gate. @@ -455,6 +455,8 @@ void LstmStep( params->cell_clip, cell_state_ptr); } + is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); // For each batch and cell: update the output gate. 
if (use_peephole && !is_cell_state_all_zeros) { VectorMultiply(cell_to_output_weights_ptr, n_cell, diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc index b7531ea2e202cd6fe012e0fa675380775016d38f..e786f785abe3aa66a9fb243dd4f332ca91676863 100644 --- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -116,10 +116,11 @@ void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, int32 reverse_scaling_divisor; int reverse_scaling_right_shift; static const int kScaledDiffIntegerBits = 5; - tflite::PreprocessLogSoftmaxScaling( + tflite::PreprocessLogSoftmaxScalingExp( beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier, &input_beta_left_shift, &reverse_scaling_divisor, &reverse_scaling_right_shift); + reverse_scaling_right_shift *= -1; // diff_min has a negative value, and is used to limit the maximum magnitude // of the diffs, which are <= 0. const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index a7b0d805a3acd35b592a35ba4266dfff4eb992cd..4cfaa0f36defa9c1f7d4a51af243c416bf09e331 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -26,7 +26,7 @@ namespace optimized_ops { // Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on // Jetson TX-2. This compiler does not support the offsetof() macro. #if defined(__aarch64__) && !defined(GOOGLE_L4T) - +#include // clang-format gets confused with this file and ends up formatting lines to // be larger than 80 characters. 
Turn off here and back on at the end of the // file. diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..c0dda4acf1a59de22d07905f2a2e2bbe422e8d21 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ + +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { + return RuntimeShape( + {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + return L2Normalization(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + return L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, + output_data, DimsToShape(output_dims)); +} + +} // namespace optimized_ops +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 0ce781db59a2cff0e0c199244b867fddf98804d6..d0008cc4fb62c2105d6817a6e44cefa974a31dbc 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -1082,10 +1082,10 @@ struct GemmlowpOutputPipeline { gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> Pipeline; - static Pipeline Make(const int32* bias_data, int output_rows, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max) { + static Pipeline MakeExp(const int32* bias_data, int output_rows, + int32 output_offset, int32 output_multiplier, + int output_left_shift, int32 output_activation_min, + int32 output_activation_max) { ColVectorMap bias_vector(bias_data, output_rows); gemmlowp::OutputStageBiasAddition bias_addition_stage; bias_addition_stage.bias_vector = bias_vector; @@ -1093,7 +1093,7 @@ struct GemmlowpOutputPipeline { quantize_down_stage; quantize_down_stage.result_offset_after_shift = output_offset; quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; - quantize_down_stage.result_shift = output_shift; + quantize_down_stage.result_shift = -output_left_shift; gemmlowp::OutputStageClamp clamp_stage; clamp_stage.min = output_activation_min; clamp_stage.max = output_activation_max; @@ -1146,8 +1146,8 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, input_data, filter_cols, batches, filter_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, batches, output_rows); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -1821,8 +1821,8 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, // 
Use dimensions M and N to construct dims for indexing directly into im2col Dims<4> im2col_dims; - im2col_dims.sizes[0] = col_dims.strides[3]; - im2col_dims.sizes[1] = row_dims.strides[3]; + im2col_dims.sizes[0] = FlatSize(col_dims); + im2col_dims.sizes[1] = FlatSize(row_dims); im2col_dims.sizes[2] = 1; im2col_dims.sizes[3] = 1; ComputeStrides(&im2col_dims); @@ -1831,8 +1831,8 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { - // Each row is an output pixel. Arrange the input data into this row in - // an order we can conveniently multiply with the filter data. + // Each im2col row is an output pixel. Arrange the input data in this + // row in an order we can conveniently multiply with the filter data. int row_offset = Offset(row_dims, out_x, out_y, batch, 0); const int in_x_origin = (out_x * stride_width) - pad_width; const int in_y_origin = (out_y * stride_height) - pad_height; @@ -1848,7 +1848,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, T* dst = im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0); if ((in_x >= 0) && (in_x < input_width)) { - // Filter pixel is within the input, copy the data. + // Filter pixel is within the input, copy the input data. T const* src = input_data + Offset(input_dims, 0, in_x, in_y, batch); memcpy(dst, src, input_depth * sizeof(T)); @@ -1858,7 +1858,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, } } } else { - // Filter row is outside the input, zero out the entire im2col row. + // Filter row is outside the input, zero out the entire filter row. 
int col_offset = Offset(col_dims, 0, 0, filter_y, 0); T* dst = im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0); @@ -1922,7 +1922,7 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, (void)im2col_dims; gemmlowp::ScopedProfilingLabel label("Conv"); - // A float set to 0x00000000h == 0.0f + // NB: static_cast(0x00000000h) == 0.0f const uint8 float_zero_byte = 0x00; const float* gemm_input_data = nullptr; const Dims<4>* gemm_input_dims = nullptr; @@ -2084,8 +2084,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, gemm_input_data, gemm_input_rows, gemm_input_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, output_cols); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -2242,8 +2242,8 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, input_data, filter_cols, output_cols, filter_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, output_cols, output_rows); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -2366,12 +2366,15 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims, } template -void L2Normalization(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +void L2Normalization(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& 
output_shape) { gemmlowp::ScopedProfilingLabel label("L2Normalization"); static_assert(Ac == FusedActivationFunctionType::kNone, ""); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { float squared_l2_norm = 0; for (int c = 0; c < depth; ++c) { @@ -2387,8 +2390,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims, } } -inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, - int* output_shift) { +inline void GetInvSqrtQuantizedMultiplierExp(int32 input, + int32* output_inv_sqrt, + int* output_shift) { *output_shift = 11; while (input >= (1 << 29)) { input /= 4; @@ -2430,31 +2434,35 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, *output_inv_sqrt <<= -*output_shift; *output_shift = 0; } + *output_shift *= kReverseShift; } -inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, +inline void L2Normalization(const uint8* input_data, + const RuntimeShape& input_shape, int32 input_zero_point, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit"); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + const int outer_size = + 
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); for (int i = 0; i < outer_size; ++i) { int32 square_l2_norm = 0; for (int c = 0; c < depth; c++) { + // Note that input_data advances by depth in the second pass below. int32 diff = input_data[c] - input_zero_point; square_l2_norm += diff * diff; } int32 inv_l2norm_multiplier; int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); + GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); for (int c = 0; c < depth; c++) { int32 diff = *input_data - input_zero_point; int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( - 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift); + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); int32 unclamped_output_val = 128 + rescaled_diff; int32 output_val = std::min(255, std::max(0, unclamped_output_val)); *output_data = static_cast(output_val); @@ -5722,6 +5730,46 @@ inline void ResizeBilinearGeneric(const float* input_data, } } +template +inline void ResizeBilinearGenericSmallChannel( + const T* input_data, const Dims<4>& input_dims, T* output_data, + const Dims<4>& output_dims, int32 batches, int32 input_height, + int32 input_width, int32 depth, int32 output_height, int32 output_width, + float height_scale, float width_scale) { + memset(output_data, 0, + batches * output_height * output_width * depth * sizeof(T)); + + T* output_ptr = &output_data[0]; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(input_x); + int32 x1 = std::min(x0 + 1, input_width - 1); + + int32 input_offset[4] = { + Offset(input_dims, 0, x0, y0, b), Offset(input_dims, 0, x1, y0, b), + 
Offset(input_dims, 0, x0, y1, b), Offset(input_dims, 0, x1, y1, b)}; + float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)), + (1 - (input_y - y0)) * (input_x - x0), + (input_y - y0) * (1 - (input_x - x0)), + (input_y - y0) * (input_x - x0)}; + + for (int d = 0; d < depth; d++) { + const T* input_ptr = &input_data[d]; + *output_ptr++ = static_cast(input_ptr[input_offset[0]] * scale[0] + + input_ptr[input_offset[1]] * scale[1] + + input_ptr[input_offset[2]] * scale[2] + + input_ptr[input_offset[3]] * scale[3]); + } + } + } + } +} + inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, @@ -5762,6 +5810,41 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } } +// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8 +// or int16 arithmetic. +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims, bool align_corners) { + gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + + float height_scale = + (align_corners && output_height > 1) + ? 
(static_cast(input_height - 1) / (output_height - 1)) + : (static_cast(input_height) / output_height); + + float width_scale = + (align_corners && output_width > 1) + ? (static_cast(input_width - 1) / (output_width - 1)) + : (static_cast(input_width) / output_width); + + ResizeBilinearGenericSmallChannel( + input_data, input_dims, output_data, output_dims, batches, input_height, + input_width, depth, output_height, output_width, height_scale, + width_scale); +} + // legacy, for compatibility with old checked-in code inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, @@ -5771,6 +5854,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, output_data, output_dims, /*align_corners=*/false); } +// legacy, for compatibility with old checked-in code +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, + output_data, output_dims, /*align_corners=*/false); +} + template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, @@ -6279,69 +6371,84 @@ void Transpose(const T* input, const Dims<4>& input_dims, T* output, } } -inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("TransposeConv"); - // THIS FUNCTION IS A COPY FROM reference_ops.h. - // To optimize, start by using the conv code with transposed weights for the - // case of stride_height = stride_width = 1. 
+template +void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, + const Dims<4>& filter_dims, int stride_width, + int stride_height, int pad_width, int pad_height, + const Dims<4>& output_dims, uint8 zero_byte, + T* im2col_data) { + gemmlowp::ScopedProfilingLabel label("TransposeIm2col"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK(im2col_data); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); const int input_height = ArraySize(input_dims, 2); const int input_width = ArraySize(input_dims, 1); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); const int filter_height = ArraySize(filter_dims, 2); const int filter_width = ArraySize(filter_dims, 1); const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); + MatchingArraySize(output_dims, 0, filter_dims, 0); // output_depth - // Although transpose convolution simplifies to convolution with transposed - // weights for strides of 1, non-unitary striding complicates matters. To - // keep this reference implementation as clear as possible, we use a "scatter" - // access pattern, where we loop through all the input elements, computing - // their influence on the output, rather than looping through the output - // elements in the typical "gather" access pattern of a conv. We therefore - // must initialize the output array to zero. 
- for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - for (int out_x = 0; out_x < output_width; ++out_x) { - for (int out_channel = 0; out_channel < output_depth; ++out_channel) { - output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = - 0.0f; - } - } - } - } + // Construct the MxN sized im2col matrix. + // The rows M, are sub-ordered B x H x W + Dims<4> row_dims; + row_dims.sizes[0] = output_width; + row_dims.sizes[1] = output_height; + row_dims.sizes[2] = batches; + row_dims.sizes[3] = 1; + ComputeStrides(&row_dims); - // Loop through input elements one at a time. + // The columns, N, are sub-ordered Kh x Kw x Din + Dims<4> col_dims; + col_dims.sizes[0] = input_depth; + col_dims.sizes[1] = filter_width; + col_dims.sizes[2] = filter_height; + col_dims.sizes[3] = 1; + ComputeStrides(&col_dims); + + // Use dimensions M and N to construct dims for indexing directly into im2col + Dims<4> im2col_dims; + im2col_dims.sizes[0] = FlatSize(col_dims); + im2col_dims.sizes[1] = FlatSize(row_dims); + im2col_dims.sizes[2] = 1; + im2col_dims.sizes[3] = 1; + ComputeStrides(&im2col_dims); + + // Build the im2col matrix by looping through all the input pixels, + // computing their influence on the output, rather than looping through all + // the output pixels. We therefore must initialize the im2col array to zero. + // This is potentially inefficient because we subsequently overwrite bytes + // set here. However, in practice memset is very fast and costs negligible. + memset(im2col_data, zero_byte, FlatSize(im2col_dims) * sizeof(T)); + + // Loop through the output batches for (int batch = 0; batch < batches; ++batch) { + // Loop through input pixels one at a time. 
for (int in_y = 0; in_y < input_height; ++in_y) { for (int in_x = 0; in_x < input_width; ++in_x) { - for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - // Loop through the output elements it will influence - const int out_x_origin = (in_x * stride_width) - pad_width; - const int out_y_origin = (in_y * stride_height) - pad_height; - for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + // Loop through the output pixels it will influence + const int out_x_origin = (in_x * stride_width) - pad_width; + const int out_y_origin = (in_y * stride_height) - pad_height; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int out_y = out_y_origin + filter_y; + // Is output pixel within height bounds? + if ((out_y >= 0) && (out_y < output_height)) { for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - for (int out_channel = 0; out_channel < output_depth; - ++out_channel) { - // Compute output element location - const int out_x = out_x_origin + filter_x; - const int out_y = out_y_origin + filter_y; - // We cannot accumulate out of bounds - if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) && - (out_y < output_height)) { - float input_value = input_data[Offset(input_dims, in_channel, - in_x, in_y, batch)]; - float filter_value = - filter_data[Offset(filter_dims, out_channel, filter_x, - filter_y, in_channel)]; - output_data[Offset(output_dims, out_channel, out_x, out_y, - batch)] += input_value * filter_value; - } + const int out_x = out_x_origin + filter_x; + // Is output pixel within width bounds? 
+ if ((out_x >= 0) && (out_x < output_width)) { + // Copy the input elements of this pixel + T const* src = + input_data + Offset(input_dims, 0, in_x, in_y, batch); + T* dst = im2col_data + + Offset(im2col_dims, + Offset(col_dims, 0, filter_x, filter_y, 0), + Offset(row_dims, out_x, out_y, batch, 0), 0, 0); + memcpy(dst, src, input_depth * sizeof(T)); } } } @@ -6351,6 +6458,31 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } } +inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + gemmlowp::ScopedProfilingLabel label("TransposeConv"); + + // Note we could use transposed weights with forward conv for unstrided + // cases. But we are already getting good performance with this code as-is. + TFLITE_DCHECK(im2col_data); + TransposeIm2col(input_data, input_dims, filter_dims, stride_width, + stride_height, pad_width, pad_height, output_dims, 0, + im2col_data); + + const auto im2col_matrix_map = + MapAsMatrixWithFirstDimAsRows(im2col_data, im2col_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map); +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index b0951aac8cbb98a181d9dcaef88770fadfc74f62..57ee859115cddbcbccae24ff639e848340d8e2ee 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -48,15 +48,15 @@ void QuantizeMultiplierGreaterThanOne(double 
double_multiplier, TFLITE_CHECK_GE(*left_shift, 0); } -void QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* right_shift) { +void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift) { TFLITE_CHECK_LT(double_multiplier, 1.); TFLITE_CHECK_GT(double_multiplier, 0.); int shift; QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift); TFLITE_CHECK_LE(shift, 0); - *right_shift = -shift; + *left_shift = shift; } void PreprocessSoftmaxScaling(double beta, double input_scale, @@ -78,20 +78,21 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, quantized_multiplier, left_shift); } -void PreprocessLogSoftmaxScaling(double beta, double input_scale, - int input_integer_bits, - int32_t* quantized_multiplier, int* left_shift, - int32_t* reverse_scaling_divisor, - int* reverse_scaling_right_shift) { +void PreprocessLogSoftmaxScalingExp(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, + int* left_shift, + int32_t* reverse_scaling_divisor, + int* reverse_scaling_left_shift) { PreprocessSoftmaxScaling(beta, input_scale, input_integer_bits, quantized_multiplier, left_shift); // Also calculate what amounts to the inverse scaling factor for the input. 
const double real_reverse_scaling_divisor = (1 << (31 - *left_shift)) / static_cast(*quantized_multiplier); - tflite::QuantizeMultiplierSmallerThanOne(real_reverse_scaling_divisor, - reverse_scaling_divisor, - reverse_scaling_right_shift); + tflite::QuantizeMultiplierSmallerThanOneExp(real_reverse_scaling_divisor, + reverse_scaling_divisor, + reverse_scaling_left_shift); } int CalculateInputRadius(int input_integer_bits, int input_left_shift) { diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index 4a217515f142b2451ebd61e423871b95cdc09748..182ee782c76fcccedc99327d47805b49bfb8580d 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -167,9 +167,9 @@ IntOut SafeCast(FloatIn x) { // this is intended as a RIGHT-shift. // // Restricted to the case where the multiplier < 1 (and non-negative). -void QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* right_shift); +void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift); // Decompose a double multiplier into a Q0.31 int32 representation of its // significand, and shift representation of its exponent. @@ -197,11 +197,12 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, int input_integer_bits, int32_t* quantized_multiplier, int* left_shift); // Like PreprocessSoftmaxScaling, but inverse scaling factors also calculated. 
-void PreprocessLogSoftmaxScaling(double beta, double input_scale, - int input_integer_bits, - int32_t* quantized_multiplier, int* left_shift, - int32_t* reverse_scaling_divisor, - int* reverse_scaling_right_shift); +void PreprocessLogSoftmaxScalingExp(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, + int* left_shift, + int32_t* reverse_scaling_divisor, + int* reverse_scaling_left_shift); // Calculate the largest input that will result in a within-bounds intermediate // result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words, // it must not overflow before we reduce the value by multiplication by the diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index 2d74b3d3849812a2dc95fabcd680aa280c99ca55..94773b47d3817d7ed7240f74545ad04e7fa4bd52 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -196,21 +196,21 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) { EXPECT_DEATH(ChooseQuantizationParams(10.0, -30.0), ""); } -TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { +TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOneExp) { auto quantize = [](double d) { int32_t q; int s; - QuantizeMultiplierSmallerThanOne(d, &q, &s); + QuantizeMultiplierSmallerThanOneExp(d, &q, &s); return std::pair{q, s}; }; EXPECT_DEATH(quantize(-0.1), ""); EXPECT_DEATH(quantize(0.0), ""); - EXPECT_THAT(quantize(0.25), Pair(1073741824, 1)); + EXPECT_THAT(quantize(0.25), Pair(1073741824, -1)); // Around 0.5 we can see the change in exponent and how we try hard to // void hitting max int32. 
- EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1)); + EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, -1)); EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0)); EXPECT_THAT(quantize(0.50), Pair(1073741824, 0)); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..6f5f6a3e6fa905f594c0361b163b5b817306dafc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ + +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +namespace reference_ops { + +inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { + return RuntimeShape( + {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + return L2Normalization(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + return L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, + output_data, DimsToShape(output_dims)); +} + +} // namespace reference_ops +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index af7bc8b91ae08323299a0eda6a8b7720bd7310ae..6cef94a606aa6fbc39f3105b9b7aca1af4092970 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -950,11 +950,14 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims, } template -void L2Normalization(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +void 
L2Normalization(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { static_assert(Ac == FusedActivationFunctionType::kNone, ""); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { float squared_l2_norm = 0; for (int c = 0; c < depth; ++c) { @@ -968,8 +971,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims, } } -inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, - int* output_shift) { +inline void GetInvSqrtQuantizedMultiplierExp(int32 input, + int32* output_inv_sqrt, + int* output_shift) { *output_shift = 11; while (input >= (1 << 29)) { input /= 4; @@ -1011,34 +1015,36 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, *output_inv_sqrt <<= -*output_shift; *output_shift = 0; } + *output_shift *= kReverseShift; } -inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, +inline void L2Normalization(const uint8* input_data, + const RuntimeShape& input_shape, int32 input_zero_point, uint8* output_data, - const Dims<4>& output_dims) { - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); + const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); for (int i = 0; i < outer_size; ++i) { int32 square_l2_norm = 0; for 
(int c = 0; c < depth; c++) { - int32 diff = - input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point; + int32 diff = input_data[depth * i + c] - input_zero_point; square_l2_norm += diff * diff; } int32 inv_l2norm_multiplier; int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); + GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); for (int c = 0; c < depth; c++) { - int32 diff = - input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point; + int32 diff = input_data[depth * i + c] - input_zero_point; int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( - 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift); + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); int32 unclamped_output_val = 128 + rescaled_diff; int32 output_val = std::min(255, std::max(0, unclamped_output_val)); - output_data[Offset(output_dims, c, i, 0, 0)] = - static_cast(output_val); + output_data[depth * i + c] = static_cast(output_val); } } } @@ -3202,9 +3208,10 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims, } } -inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, +template +inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims, const int32* output_size_data, - const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_size_dims, T* output_data, const Dims<4>& output_dims, bool align_corners) { int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); int32 input_height = ArraySize(input_dims, 2); @@ -3236,15 +3243,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, int32 x0 = static_cast(std::floor(input_x)); int32 x1 = std::min(x0 + 1, input_width - 1); for (int c = 0; c < depth; ++c) { - float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] * - (1 - (input_y - y0)) * - (1 - (input_x - x0)) + - 
input_data[Offset(input_dims, c, x0, y1, b)] * - (input_y - y0) * (1 - (input_x - x0)) + - input_data[Offset(input_dims, c, x1, y0, b)] * - (1 - (input_y - y0)) * (input_x - x0) + - input_data[Offset(input_dims, c, x1, y1, b)] * - (input_y - y0) * (input_x - x0); + T interpolation = + static_cast(input_data[Offset(input_dims, c, x0, y0, b)] * + (1 - (input_y - y0)) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x0, y1, b)] * + (input_y - y0) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x1, y0, b)] * + (1 - (input_y - y0)) * (input_x - x0) + + input_data[Offset(input_dims, c, x1, y1, b)] * + (input_y - y0) * (input_x - x0)); output_data[Offset(output_dims, c, x, y, b)] = interpolation; } } @@ -3257,8 +3264,18 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, const Dims<4>& output_dims) { - ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, - output_data, output_dims, /*align_corners=*/false); + ResizeBilinear(input_data, input_dims, output_size_data, + output_size_dims, output_data, output_dims, + /*align_corners=*/false); +} + +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, + output_size_dims, output_data, output_dims, + /*align_corners=*/false); } template @@ -3808,10 +3825,11 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, int stride_width, int stride_height, int pad_width, int pad_height, float* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, float* /*im2col_data*/, + const Dims<4>& /*im2col_dims*/) { const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const 
int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); const int input_height = ArraySize(input_dims, 2); const int input_width = ArraySize(input_dims, 1); const int filter_height = ArraySize(filter_dims, 2); @@ -3851,8 +3869,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, float input_value = input_data[Offset(input_dims, in_channel, in_x, in_y, batch)]; float filter_value = - filter_data[Offset(filter_dims, out_channel, filter_x, - filter_y, in_channel)]; + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] += input_value * filter_value; } diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc similarity index 60% rename from tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc rename to tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc index c1c50dff4d2a966bff70853701334f599ee03849..3d8765f11b2941ef5871c7db8e3582e506713aa6 100644 --- a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc @@ -24,9 +24,10 @@ limitations under the License. 
namespace tflite { namespace { +template void TestOneResizeBilinear(int batch, int depth, int input_width, int input_height, int output_width, - int output_height) { + int output_height, float error_threshold) { Dims<4> input_dims_inference = MakeDimsForInference(depth, input_width, input_height, batch); Dims<4> output_dims_inference = @@ -36,14 +37,15 @@ void TestOneResizeBilinear(int batch, int depth, int input_width, const int output_buffer_size = RequiredBufferSizeForDims(output_dims_inference); - std::vector input_data(input_buffer_size, 0); - std::vector reference_output_data(output_buffer_size, 0); + std::vector input_data(input_buffer_size, 0); + std::vector reference_output_data(output_buffer_size, 0); // Initialize the output data with something other than zero, so we can catch // issue with kernels failing to initialize the output. - std::vector output_data(output_buffer_size, 3.1415); + std::vector output_data(output_buffer_size, 3); - const float input_amplitude = 1.f; - FillRandom(&input_data, -input_amplitude, input_amplitude); + const T min_amplitude = static_cast(0); + const T max_amplitude = static_cast(255); + FillRandom(&input_data, min_amplitude, max_amplitude); Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1); std::vector output_size_data = {output_height, output_width}; @@ -58,14 +60,46 @@ void TestOneResizeBilinear(int batch, int depth, int input_width, double sum_diff = 0; float max_abs_val = 0; for (int i = 0; i < output_buffer_size; i++) { - sum_diff += std::abs(output_data[i] - reference_output_data[i]); - max_abs_val = std::max(max_abs_val, std::abs(reference_output_data[i])); + sum_diff += std::abs(static_cast(output_data[i]) - + static_cast(reference_output_data[i])); + max_abs_val = std::max( + max_abs_val, std::abs(static_cast(reference_output_data[i]))); } if (sum_diff != 0.f) { const float mean_diff = static_cast(sum_diff / output_buffer_size); const float relative_error = std::abs(mean_diff) / max_abs_val; - 
ASSERT_LT(relative_error, 1e-5f); + ASSERT_LT(relative_error, error_threshold); + } +} + +TEST(ResizeBilinear, TestResizeBilinear8Bit) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 0.025); + } +} + +TEST(ResizeBilinear2x2, TestResizeBilinear8Bit) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = input_width * 2; + const int output_height = input_height * 2; + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); } } @@ -79,8 +113,8 @@ TEST(ResizeBilinear, TestResizeBilinear) { const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); - TestOneResizeBilinear(batch, depth, input_width, input_height, output_width, - output_height); + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); } } @@ -94,8 +128,8 @@ TEST(ResizeBilinear2x2, TestResizeBilinear) { const int output_width = input_width * 2; const int output_height = input_height * 2; - TestOneResizeBilinear(batch, depth, input_width, input_height, output_width, - output_height); 
+ TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); } } } // namespace diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index ce887cea8b794b4b0cfd31722581cf9327be625e..518bee1c6369d3ce93d1b98e19dba7615b5844dc 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -34,6 +34,11 @@ inline uint8_t* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.uint8 : nullptr; } +template <> +inline int16_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + template <> inline int32_t* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.i32 : nullptr; @@ -62,6 +67,11 @@ inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.uint8 : nullptr; } +template <> +inline const int16_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + template <> inline const int32_t* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.i32 : nullptr; @@ -114,6 +124,19 @@ inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { return GetTensorDims(dims->data, dims->size); } +inline RuntimeShape GetTensorShape(std::vector data) { + return RuntimeShape(data.size(), data.data()); +} + +inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return RuntimeShape(); + } + + auto* dims = tensor->dims; + return RuntimeShape(dims->size, dims->data); +} + // A list of tensors in a format that can be used by kernels like split and // concatenation. 
template diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 0c7fb7a76a5075652e705e65f5379596dfa77c78..64f4881a4686525fa6b56c30c1411fe5c91334b2 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -65,6 +65,10 @@ class RuntimeShape { ReplaceWith(dimensions_count, dims_data); } + RuntimeShape(const std::initializer_list init_list) : size_(0) { + BuildFrom(init_list); + } + ~RuntimeShape() { if (size_ > kMaxSmallSize) { delete[] dims_pointer_; @@ -121,6 +125,10 @@ class RuntimeShape { } } + inline void BuildFrom(const std::initializer_list init_list) { + BuildFrom>(init_list); + } + // Returns the total count of elements, that is the size when flattened into a // vector. inline int FlatSize() const { @@ -142,6 +150,22 @@ class RuntimeShape { }; }; +// Converts inference-style shape to legacy tflite::Dims<4>. +inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) { + tflite::Dims<4> result; + const int dimensions_count = array_shape.DimensionsCount(); + TFLITE_CHECK_LE(dimensions_count, 4); + int cum_prod = 1; + for (int i = 0; i < 4; i++) { + const int new_dim = + (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1; + result.sizes[i] = new_dim; + result.strides[i] = cum_prod; + cum_prod *= new_dim; + } + return result; +} + // Gets next index to iterate through a multidimensional array. 
inline bool NextIndex(const int num_dims, const int* dims, int* current) { TFLITE_DCHECK_GT(num_dims, 0); @@ -194,6 +218,15 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims, return offset; } +inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < shape.Dims(0)); + TFLITE_DCHECK(i1 >= 0 && i1 < shape.Dims(1)); + TFLITE_DCHECK(i2 >= 0 && i2 < shape.Dims(2)); + TFLITE_DCHECK(i3 >= 0 && i3 < shape.Dims(3)); + const int* dims_data = shape.DimsData(); + return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3; +} + inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]); TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]); @@ -208,6 +241,9 @@ inline int Offset(const Dims<4>& dims, int* index) { } // Get array size, DCHECKing that the dim index is in range. +// +// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims() +// already performs this check. template int ArraySize(const Dims& array, int index) { TFLITE_DCHECK(index >= 0 && index < N); @@ -229,6 +265,21 @@ int MatchingArraySize(const ArrayType1& array1, int index1, return MatchingArraySize(array1, index1, args...); } +// Get common shape dim, DCHECKing that they all agree. +inline int MatchingDim(const RuntimeShape& shape1, int index1, + const RuntimeShape& shape2, int index2) { + TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2)); + return shape1.Dims(index1); +} + +template +int MatchingDim(const RuntimeShape& shape1, int index1, + const RuntimeShape& shape2, int index2, Args... args) { + TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2)); + return MatchingDim(shape1, index1, args...); +} + +// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize(). 
template inline int FlatSize(const Dims& dims) { int flat_size = 1; @@ -348,6 +399,72 @@ inline int MatchingFlatSizeSkipDim(const Dims& dims, int skip_dim, check_dims_3); } +// Data is required to be contiguous, and so many operators can use either the +// full array flat size or the flat size with one dimension skipped (commonly +// the depth). +inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) { + const int dims_count = shape.DimensionsCount(); + TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count); + const auto* dims_data = shape.DimsData(); + int flat_size = 1; + for (int i = 0; i < dims_count; ++i) { + flat_size *= (i == skip_dim) ? 1 : dims_data[i]; + } + return flat_size; +} + +// A combination of MatchingFlatSize() and FlatSizeSkipDim(). +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return FlatSizeSkipDim(shape, skip_dim); +} + +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1); +} + +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2); +} + +inline int 
MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2, + const RuntimeShape& check_shape_3) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2, + check_shape_3); +} + template bool IsPackedWithoutStrides(const Dims& dims) { int expected_stride = 1; diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc index 3205c1cc52724207904621a5870636841ef379fe..a7b54c6b842332feb2d9e7179e79ae054bd23bb9 100644 --- a/tensorflow/contrib/lite/kernels/l2norm.cc +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -70,8 +70,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32) { #define TF_LITE_L2NORM(type) \ type::L2Normalization( \ - GetTensorData(input), GetTensorDims(input), \ - GetTensorData(output), GetTensorDims(output)) + GetTensorData(input), GetTensorShape(input), \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_L2NORM(reference_ops); @@ -81,10 +81,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } #undef TF_LITE_L2NORM } else if (output->type == kTfLiteUInt8) { -#define TF_LITE_L2NORM(type) \ - type::L2Normalization(GetTensorData(input), GetTensorDims(input), \ - input->params.zero_point, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_L2NORM(type) \ + type::L2Normalization(GetTensorData(input), GetTensorShape(input), \ + input->params.zero_point, \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_L2NORM(reference_ops); diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 
62f4e94a386fbbc6987e8a6dc1a9a47ce3349cbb..b69a221447db963bcd3a7e6a69f132fe3767bfd1 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -120,8 +120,9 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, double real_multiplier = input1->params.scale * input2->params.scale / output->params.scale; - QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, - &output_shift); + QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier, + &output_shift); + output_shift *= -1; int32 output_activation_min, output_activation_max; CalculateActivationRangeUint8(params->activation, output, diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 6c68bb2f31003e08585b2fa3df0efe6d291ddb36..7bb28d4de7402a45954691a2e031e3b6b7433ffb 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -73,6 +73,7 @@ TfLiteRegistration* Register_SQUEEZE(); TfLiteRegistration* Register_STRIDED_SLICE(); TfLiteRegistration* Register_EXP(); TfLiteRegistration* Register_TOPK_V2(); +TfLiteRegistration* Register_LOG(); TfLiteRegistration* Register_LOG_SOFTMAX(); TfLiteRegistration* Register_CAST(); TfLiteRegistration* Register_DEQUANTIZE(); @@ -150,6 +151,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); AddBuiltin(BuiltinOperator_EXP, Register_EXP()); AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); + AddBuiltin(BuiltinOperator_LOG, Register_LOG()); AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); AddBuiltin(BuiltinOperator_CAST, Register_CAST()); AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE()); diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h index b928f1b302580d52f708bbf85dfcfc0f79ff1e69..940718d67e70b7206227b891ea529cb9e9619161 100644 --- 
a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/contrib/lite/kernels/register.h @@ -32,4 +32,4 @@ class BuiltinOpResolver : public MutableOpResolver { } // namespace ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index f2092eaa36db32ebbc959ac23365bb13dd034e68..86c4cd3ee88013ca4174f444d0388bc036d9cde6 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -61,12 +61,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1); - // TODO(ahentz): Our current implementations only support float32. - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32); // ResizeBilinear creates a float tensor even when the input is made of // integers. 
- output->type = kTfLiteFloat32; + output->type = input->type; if (!IsConstantTensor(size)) { SetTensorToDynamic(output); @@ -90,17 +88,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } if (output->type == kTfLiteFloat32) { -#define TF_LITE_RESIZE_BILINEAR(type) \ - type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ - GetTensorData(size), GetTensorDims(size), \ - GetTensorData(output), GetTensorDims(output), \ +#define TF_LITE_RESIZE_BILINEAR(type, datatype) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + GetTensorData(size), GetTensorDims(size), \ + GetTensorData(output), GetTensorDims(output), \ params->align_corners) if (kernel_type == kReference) { - TF_LITE_RESIZE_BILINEAR(reference_ops); + TF_LITE_RESIZE_BILINEAR(reference_ops, float); } if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { - TF_LITE_RESIZE_BILINEAR(optimized_ops); + TF_LITE_RESIZE_BILINEAR(optimized_ops, float); + } + } else if (output->type == kTfLiteUInt8) { + if (kernel_type == kReference) { + TF_LITE_RESIZE_BILINEAR(reference_ops, uint8_t); + } + if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { + TF_LITE_RESIZE_BILINEAR(optimized_ops, uint8_t); } #undef TF_LITE_RESIZE_BILINEAR } else { diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 4e03f3820a5c14ee1692c553db61e385716b1723..10caffea03ebcec7862df1627541ac3d076b04e4 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -22,6 +22,7 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using uint8 = std::uint8_t; class ResizeBilinearOpModel : public SingleOpModel { public: @@ -34,7 +35,7 @@ class ResizeBilinearOpModel : public SingleOpModel { } else { size_ = AddInput({TensorType_INT32, {2}}); } - output_ = AddOutput(TensorType_FLOAT32); // Always float. 
+ output_ = AddOutput(input.type); SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, CreateResizeBilinearOptions(builder_).Union()); @@ -45,12 +46,16 @@ class ResizeBilinearOpModel : public SingleOpModel { } } - void SetInput(std::initializer_list data) { + template + void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } void SetSize(std::initializer_list data) { PopulateTensor(size_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + std::vector GetOutput() { + return ExtractVector(output_); + } private: int input_; @@ -60,60 +65,121 @@ class ResizeBilinearOpModel : public SingleOpModel { TEST(ResizeBilinearOpTest, HorizontalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); - m.SetInput({3, 6}); + m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); - const_m.SetInput({3, 6}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(ResizeBilinearOpTest, HorizontalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}}); + m.SetInput({3, 6}); + m.SetSize({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } TEST(ResizeBilinearOpTest, VerticalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); - m.SetInput({3, 9}); + m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); - 
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); - const_m.SetInput({3, 9}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(ResizeBilinearOpTest, VerticalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}}); + m.SetInput({3, 9}); + m.SetSize({3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } TEST(ResizeBilinearOpTest, TwoDimensionalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); - m.SetInput({ + m.SetInput({ 3, 6, // 9, 12 // }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 6, // 9, 12 // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), 
ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); } TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}); - m.SetInput({ + m.SetInput({ 3, 6, // 9, 12, // 4, 10, // @@ -121,60 +187,123 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - 4, 8, 10, // - 8, 12, 14, // - 10, 14, 16, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 6, // 9, 12, // 4, 10, // 10, 16 // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - 4, 8, 10, // - 8, 12, 14, // - 10, 14, 16, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}); - m.SetInput({ + m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 14, 12, 16, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 
11, 14, 12, 16, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 14, 12, 16, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 13, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 13, 16, // + }))); } +TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}); + m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 13, 12, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3}); + const_m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 13, 12, 16, // + }))); +} } // namespace } // namespace tflite diff --git 
a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index bdcaab8e2fa8a3342e0958635ec77a15a7238ccf..a8b803589962032db3ed579d31e8b736c3afada0 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -126,16 +126,19 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, int32 input1_multiplier; int input1_shift; - QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, - &input1_shift); + QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, + &input1_multiplier, &input1_shift); + input1_shift *= -1; int32 input2_multiplier; int input2_shift; - QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, - &input2_shift); + QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, + &input2_multiplier, &input2_shift); + input2_shift *= -1; int32 output_multiplier; int output_shift; - QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, - &output_shift); + QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, + &output_multiplier, &output_shift); + output_shift *= -1; int32 output_activation_min, output_activation_max; CalculateActivationRangeUint8(params->activation, output, diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc index 3c99661029ed1ac881536f83519dcec355c60d50..8b9deeed20d761876d526c07eb78b602ca7314dc 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -79,7 +79,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Ensure that weights and inputs have the same channel dimension. // Note: TOCO will reorder weights in the following format: OHWI. 
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3), - SizeOfDimension(weights, 0)); + SizeOfDimension(weights, 3)); if (!IsConstantTensor(output_shape)) { SetTensorToDynamic(output); @@ -119,10 +119,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Currently only support float32. switch (input->type) { case kTfLiteFloat32: - optimized_ops::TransposeConv( + reference_ops::TransposeConv( GetTensorData(input), GetTensorDims(input), GetTensorData(weights), GetTensorDims(weights), stride_width, stride_height, padding_size.width, padding_size.height, + GetTensorData(output), GetTensorDims(output), + // Last two args specify im2col which reference_ops ignores. + // (Note this does not lead to a performance regression, as the + // previous optimized version was just a copy of the reference code.) + // TODO(b/110208176): Allocate im2col tensors and switch to + // optimized_ops. GetTensorData(output), GetTensorDims(output)); break; default: diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc index 52be08934997f484337e4a3592bc7af832601695..55df8971806ed0baae9f5bcaebd24fb8065ec300 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc @@ -88,10 +88,10 @@ TEST(TransposeConvOpModelTest, SimpleTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1]) TEST(TransposeConvOpModelTest, TwoFiltersTest) { - TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_SAME, 1, 1); + TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_SAME, 1, 1); m.PopulateTensor(m.output_shape(), {1, 4, 4, 1}); - m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, - 8, 10, 12, 14, 16, 18}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18}); m.PopulateTensor( m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 
12, 13, 14, 15, 16, @@ -117,10 +117,10 @@ TEST(TransposeConvOpModelTest, TwoFiltersTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18]) TEST(TransposeConvOpModelTest, PaddingValidTest) { - TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_VALID, 1, 1); + TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_VALID, 1, 1); m.PopulateTensor(m.output_shape(), {1, 6, 6, 1}); - m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, - 8, 10, 12, 14, 16, 18}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18}); m.PopulateTensor( m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, @@ -171,10 +171,10 @@ TEST(TransposeConvOpModelTest, StrideValidTest) { // [1, 2, 2, 1 ], // "VALID") TEST(TransposeConvOpModelTest, MultiChannelTest) { - TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 2}, Padding_VALID, 2, 2); + TransposeConvOpModel m({1, 2, 2, 1}, {2, 3, 3, 1}, Padding_VALID, 2, 2); m.PopulateTensor(m.output_shape(), {1, 5, 5, 2}); - m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18}); + m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, + 8, 10, 12, 14, 16, 18}); m.PopulateTensor(m.input(), {1, 2, 3, 4}); m.Invoke(); diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index d78b6eae90f17a1c6775ba43647ae67720038207..bc62e4cc2d8af9b1c242900a9730f4fae3b92a6c 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -45,6 +45,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_FLOAT32: *type = kTfLiteFloat32; break; + case TensorType_INT16: + *type = kTfLiteInt16; + break; case TensorType_INT32: *type = kTfLiteInt32; break; @@ -322,12 +325,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = nullptr; switch (op_type) { - case 
BuiltinOperator_CALL: - // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are - // ok for now, since there is no call implementation either. - break; - case BuiltinOperator_CUSTOM: - break; case BuiltinOperator_CONV_2D: { TfLiteConvParams* params = MallocPOD(); if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { @@ -343,21 +340,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_TANH: - case BuiltinOperator_LOGISTIC: - case BuiltinOperator_RELU: - case BuiltinOperator_RELU_N1_TO_1: - case BuiltinOperator_RELU6: - case BuiltinOperator_CONCAT_EMBEDDINGS: - case BuiltinOperator_EXP: - case BuiltinOperator_TOPK_V2: - case BuiltinOperator_LOG_SOFTMAX: - case BuiltinOperator_DEQUANTIZE: - case BuiltinOperator_PRELU: - case BuiltinOperator_FLOOR: - case BuiltinOperator_NEG: - case BuiltinOperator_SIN: - break; case BuiltinOperator_CAST: { TfLiteCastParams* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_CastOptions()) { @@ -445,9 +427,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_EMBEDDING_LOOKUP: - // no-op. 
- break; case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { TfLiteEmbeddingLookupSparseParams* params = MallocPOD(); @@ -579,12 +558,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_PAD: { - break; - } - case BuiltinOperator_PADV2: { - break; - } case BuiltinOperator_RESHAPE: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { @@ -624,15 +597,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_SPACE_TO_BATCH_ND: { - break; - } - case BuiltinOperator_BATCH_TO_SPACE_ND: { - break; - } - case BuiltinOperator_TRANSPOSE: { - break; - } case BuiltinOperator_MEAN: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_MeanOptions()) { @@ -672,10 +636,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_MAXIMUM: - case BuiltinOperator_MINIMUM: { - break; - } case BuiltinOperator_ARG_MAX: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) { @@ -685,18 +645,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_GREATER: - case BuiltinOperator_GREATER_EQUAL: - case BuiltinOperator_LESS: - case BuiltinOperator_LESS_EQUAL: - case BuiltinOperator_EQUAL: - case BuiltinOperator_NOT_EQUAL: - case BuiltinOperator_SELECT: { - break; - } - case BuiltinOperator_SLICE: { - break; - } case BuiltinOperator_TRANSPOSE_CONV: { TfLiteTransposeConvParams* params = MallocPOD(); @@ -724,10 +672,46 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, error_reporter->Report("DELEGATE op shouldn't exist in model."); return kTfLiteError; } + + // Below are the ops with no builtin_data 
strcture. + case BuiltinOperator_BATCH_TO_SPACE_ND: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. + case BuiltinOperator_CALL: + case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_CUSTOM: + case BuiltinOperator_DEQUANTIZE: + case BuiltinOperator_EMBEDDING_LOOKUP: + case BuiltinOperator_EQUAL: + case BuiltinOperator_EXP: case BuiltinOperator_EXPAND_DIMS: - case BuiltinOperator_TILE: { + case BuiltinOperator_FLOOR: + case BuiltinOperator_GREATER: + case BuiltinOperator_GREATER_EQUAL: + case BuiltinOperator_LESS: + case BuiltinOperator_LESS_EQUAL: + case BuiltinOperator_LOG: + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_LOG_SOFTMAX: + case BuiltinOperator_MAXIMUM: + case BuiltinOperator_MINIMUM: + case BuiltinOperator_NEG: + case BuiltinOperator_NOT_EQUAL: + case BuiltinOperator_PAD: + case BuiltinOperator_PADV2: + case BuiltinOperator_PRELU: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU6: + case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_SELECT: + case BuiltinOperator_SIN: + case BuiltinOperator_SLICE: + case BuiltinOperator_SPACE_TO_BATCH_ND: + case BuiltinOperator_TANH: + case BuiltinOperator_TILE: + case BuiltinOperator_TOPK_V2: + case BuiltinOperator_TRANSPOSE: break; - } } return kTfLiteOk; } @@ -868,7 +852,16 @@ TfLiteStatus InterpreterBuilder::ParseTensors( const char* buffer_ptr; TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size)); + bool is_variable = tensor->is_variable(); if (buffer_ptr) { + if (is_variable) { + error_reporter_->Report( + "Tensor %d is a variable tensor with buffer. 
" + "It's not supported now.\n", + i); + status = kTfLiteError; + } + if (interpreter->SetTensorParametersReadOnly( i, type, get_name(tensor), dims, quantization, buffer_ptr, buffer_size, allocation_) != kTfLiteOk) { @@ -877,8 +870,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors( status = kTfLiteError; } } else { - if (interpreter->SetTensorParametersReadWrite( - i, type, get_name(tensor), dims, quantization) != kTfLiteOk) { + if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor), + dims, quantization, + is_variable) != kTfLiteOk) { error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", i); status = kTfLiteError; @@ -962,6 +956,15 @@ TfLiteStatus InterpreterBuilder::operator()( if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) return cleanup_and_error(); + std::vector variables; + for (int i = 0; i < (*interpreter)->tensors_size(); ++i) { + auto* tensor = (*interpreter)->tensor(i); + if (tensor->is_variable) { + variables.push_back(i); + } + } + (**interpreter).SetVariables(variables); + return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 605ce7d6fc1e4375409070f710e20de0c3e1352f..999c31d4bff9279810a3661f0bb342cc4ef3ddaa 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -234,7 +234,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, next_id++; }; - auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); }; + auto add_add_params = [&add_scalar_int32](void* data) { + auto* builtin = reinterpret_cast(data); + add_scalar_int32(builtin->activation); + }; auto add_pooling_params = [&add_scalar_int32](void* data) { auto builtin = reinterpret_cast(data); @@ -345,11 +348,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, switch (builtin) { case tflite::BuiltinOperator_ADD: nn_op_type = ANEURALNETWORKS_ADD; - add_add_params(); + add_add_params(node.builtin_data); 
break; case tflite::BuiltinOperator_MUL: nn_op_type = ANEURALNETWORKS_MUL; - add_add_params(); + add_add_params(node.builtin_data); break; case tflite::BuiltinOperator_AVERAGE_POOL_2D: add_pooling_params(node.builtin_data); @@ -490,6 +493,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SELECT: case tflite::BuiltinOperator_SLICE: case tflite::BuiltinOperator_SIN: + case tflite::BuiltinOperator_LOG: case tflite::BuiltinOperator_TRANSPOSE_CONV: case tflite::BuiltinOperator_TILE: case tflite::BuiltinOperator_EXPAND_DIMS: diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc index dfdd80ea8a42af683632be1d7e8ab0062847077d..3af809a2a1034c411881bfc6a919562d326e99cf 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.cc +++ b/tensorflow/contrib/lite/optional_debug_tools.cc @@ -50,6 +50,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteString"; case kTfLiteBool: return "kTfLiteBool"; + case kTfLiteInt16: + return "kTfLiteInt16"; } return "(invalid)"; } diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD index c31189f2b1f1ad6e3d8e2f5fcae9b6c2ef8eaf23..a162b87b8f98576ec7c3b2623d1d34f2baef6cce 100644 --- a/tensorflow/contrib/lite/profiling/BUILD +++ b/tensorflow/contrib/lite/profiling/BUILD @@ -2,9 +2,11 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + common_copts = [ "-Wall", -] +] + tflite_copts() cc_library( name = "profiler", @@ -36,12 +38,14 @@ cc_library( name = "time", srcs = ["time.cc"], hdrs = ["time.h"], + copts = common_copts, ) cc_library( name = "profile_summarizer", srcs = ["profile_summarizer.cc"], hdrs = ["profile_summarizer.h"], + copts = common_copts, deps = [ ":profiler", "//tensorflow/contrib/lite:framework", @@ -53,6 +57,7 @@ cc_library( cc_test( name = "profile_summarizer_test", srcs = 
["profile_summarizer_test.cc"], + copts = common_copts, deps = [ ":profile_summarizer", "//tensorflow/contrib/lite:framework", diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc index 6f2c9cd2b39a1d6be77a10b18658665874067d87..45388b500c7897c8b33b49eb6ab4e9f8c4fdb37c 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer.cc +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc @@ -85,11 +85,18 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter, return details; } +tensorflow::StatSummarizerOptions GetProfileSummarizerOptions() { + auto options = tensorflow::StatSummarizerOptions(); + options.show_summary = true; + options.show_memory = false; + return options; +} + } // namespace ProfileSummarizer::ProfileSummarizer() - : stats_calculator_(new ::tensorflow::StatsCalculator( - tensorflow::StatSummarizerOptions())) {} + : stats_calculator_( + new ::tensorflow::StatsCalculator(GetProfileSummarizerOptions())) {} void ProfileSummarizer::ProcessProfiles( const std::vector& profile_stats, diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 7e6ff6c0a8314e71a64f27916a6189f229b81ab4..27909a9458f6b09f96cb556a5254f01e54f46e05 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -57,8 +57,9 @@ py_library( ":interpreter", ":lite_constants", ":op_hint", - "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python:graph_util", + "//tensorflow/python/saved_model:constants", + "//tensorflow/python/saved_model:loader", "//tensorflow/python/tools:freeze_graph_lib", ], ) diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index 08f3f8bf32981f2ef0c66f0ce312b28e9d90b260..c038c88945b71f30bf091a1098dcf853f5415b1b 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -111,27 
+111,27 @@ def tensor_name(x): return x.name.split(":")[0] -def toco_convert(input_data, - input_tensors, - output_tensors, - inference_type=lite_constants.FLOAT, - inference_input_type=None, - input_format=lite_constants.TENSORFLOW_GRAPHDEF, - output_format=lite_constants.TFLITE, - quantized_input_stats=None, - default_ranges_stats=None, - drop_control_dependency=True, - reorder_across_fake_quant=False, - allow_custom_ops=False, - change_concat_input_ranges=False, - quantize_weights=False): - """Convert a model using TOCO from `input_format` to `output_format`. +def build_toco_convert_protos(input_tensors, + output_tensors, + inference_type=lite_constants.FLOAT, + inference_input_type=None, + input_format=lite_constants.TENSORFLOW_GRAPHDEF, + output_format=lite_constants.TFLITE, + quantized_input_stats=None, + default_ranges_stats=None, + drop_control_dependency=True, + reorder_across_fake_quant=False, + allow_custom_ops=False, + change_concat_input_ranges=False, + quantize_weights=False, + dump_graphviz_dir=None, + dump_graphviz_video=False): + """Builds protocol buffers describing a conversion of a model using TOCO. Typically this is to convert from TensorFlow GraphDef to TFLite, in which case the default `input_format` and `output_format` are sufficient. Args: - input_data: Input data (i.e. often `sess.graph_def`). input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). @@ -170,10 +170,16 @@ def toco_convert(input_data, weights followed by dequantize operations. Computation is still done in float, but reduces model size (at the cost of accuracy and latency). (default False) + dump_graphviz_dir: Full filepath of folder to dump the graphs at various + stages of processing GraphViz .dot files. Preferred over + --output_format=GRAPHVIZ_DOT in order to keep the requirements of the + output file. 
(default None) + dump_graphviz_video: Boolean indicating whether to dump the graph after + every graph transformation. (default False) Returns: - The converted data. For example if TFLite was the destination, then - this will be a tflite flatbuffer in a bytes array. + model_flags, toco_flags: two protocol buffers describing the conversion + process. Raises: ValueError: If the input tensor type is unknown @@ -193,7 +199,9 @@ def toco_convert(input_data, if default_ranges_stats: toco.default_ranges_min = default_ranges_stats[0] toco.default_ranges_max = default_ranges_stats[1] - + if dump_graphviz_dir: + toco.dump_graphviz_dir = dump_graphviz_dir + toco.dump_graphviz_include_video = dump_graphviz_video model = _model_flags_pb2.ModelFlags() model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): @@ -222,10 +230,35 @@ def toco_convert(input_data, for output_tensor in output_tensors: model.output_arrays.append(tensor_name(output_tensor)) + return model, toco + + +def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): + """Convert a model using TOCO. - # TODO(aselle): Consider handling the case of allowing quantized - # inputs to be converted to float (via the toco.inference_input_type field). - data = toco_convert_protos(model.SerializeToString(), - toco.SerializeToString(), + Typically this function is used to convert from TensorFlow GraphDef to TFLite. + Conversion can be customized by providing arguments that are forwarded to + `build_toco_convert_protos` (see documentation for details). + + Args: + input_data: Input data (i.e. often `sess.graph_def`). + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + *args: See `build_toco_convert_protos`. + **kwargs: See `build_toco_convert_protos`. + + Returns: + The converted data.
For example if TFLite was the destination, then + this will be a tflite flatbuffer in a bytes array. + + Raises: + Defined in `build_toco_convert_protos`. + """ + model_flags, toco_flags = build_toco_convert_protos(input_tensors, + output_tensors, + *args, **kwargs) + data = toco_convert_protos(model_flags.SerializeToString(), + toco_flags.SerializeToString(), input_data.SerializeToString()) return data diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index 5dad49f1ed29f3bd57b1b120808ef645adee760c..1553464b9fe30f596c151bcc67efe891bb913ba3 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -19,13 +19,12 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.lite.python.convert import tensor_name -from tensorflow.contrib.saved_model.python.saved_model import reader -from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader @@ -58,21 +57,8 @@ def _get_meta_graph_def(saved_model_dir, tag_set): Raises: ValueError: No valid MetaGraphDef for given tag_set. 
""" - saved_model = reader.read_saved_model(saved_model_dir) - tag_sets = [] - result_meta_graph_def = None - for meta_graph_def in saved_model.meta_graphs: - meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags) - tag_sets.append(meta_graph_tag_set) - if meta_graph_tag_set == tag_set: - result_meta_graph_def = meta_graph_def - logging.info("The given saved_model contains the following tags: %s", - tag_sets) - if result_meta_graph_def is not None: - return result_meta_graph_def - else: - raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible " - "values are '{}'. ".format(tag_set, tag_sets)) + with session.Session(graph=ops.Graph()) as sess: + return loader.load(sess, tag_set, saved_model_dir) def _get_signature_def(meta_graph, signature_key): @@ -97,9 +83,7 @@ def _get_signature_def(meta_graph, signature_key): raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible " "values are '{}'.".format(signature_key, ",".join(signature_def_keys))) - signature_def = signature_def_utils.get_signature_def_by_key( - meta_graph, signature_key) - return signature_def + return signature_def_map[signature_key] def _get_inputs_outputs(signature_def): @@ -247,6 +231,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, ValueError: SavedModel doesn't contain a MetaGraphDef identified by tag_set. signature_key is not in the MetaGraphDef. + assets/ directory is in the MetaGraphDef. input_shapes does not match the length of input_arrays. input_arrays or output_arrays are not valid. """ @@ -255,9 +240,13 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, signature_def = _get_signature_def(meta_graph, signature_key) inputs, outputs = _get_inputs_outputs(signature_def) + # Check SavedModel for assets directory. 
+ collection_def = meta_graph.collection_def + if constants.ASSETS_KEY in collection_def: + raise ValueError("SavedModels with assets/ directory are not supported.") + graph = ops.Graph() with session.Session(graph=graph) as sess: - # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory. loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir) # Gets input and output tensors. diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 5f304ad45d400b13e20bda8184b5b40cfe13f6c2..e5e5c4fb029d6964fa0f26ae632a2b8e912d1cab 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -68,6 +68,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_FLOAT32; case kTfLiteInt32: return NPY_INT32; + case kTfLiteInt16: + return NPY_INT16; case kTfLiteUInt8: return NPY_UINT8; case kTfLiteInt64: @@ -90,6 +92,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { return kTfLiteFloat32; case NPY_INT32: return kTfLiteInt32; + case NPY_INT16: + return kTfLiteInt16; case NPY_UINT8: return kTfLiteUInt8; case NPY_INT64: diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index 01320af7a9ea3a652020e2b42300da6081ff68e5..c02aa3804367f787016ef78fc8557005507f051b 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +// Place `` before to avoid build failures in macOS. +#include #include // We forward declare TFLite classes here to avoid exposing them to SWIG. 
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 477faabdb90a6652495400f23a7d98fdc8aa5169..8315066cd129a137b9159690123ae47bee18c1c8 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -22,6 +22,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. @@Interpreter @@OpHint @@convert_op_hints_to_stubs +@@build_toco_convert_protos @@FLOAT @@QUANTIZED_UINT8 @@ -38,6 +39,7 @@ from six import PY3 from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError from tensorflow.contrib.lite.python import lite_constants as constants +from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import from tensorflow.contrib.lite.python.convert import tensor_name from tensorflow.contrib.lite.python.convert import toco_convert from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import @@ -98,6 +100,12 @@ class TocoConverter(object): weights followed by dequantize operations. Computation is still done in float, but reduces model size (at the cost of accuracy and latency). (default False) + dump_graphviz_dir: Full filepath of folder to dump the graphs at various + stages of processing GraphViz .dot files. Preferred over + --output_format=GRAPHVIZ_DOT in order to keep the requirements of the + output file. (default None) + dump_graphviz_video: Boolean indicating whether to dump the graph after + every graph transformation. (default False) Example usage: @@ -140,6 +148,8 @@ class TocoConverter(object): self.change_concat_input_ranges = False self.allow_custom_ops = False self.quantize_weights = False + self.dump_graphviz_dir = None + self.dump_graphviz_video = False @classmethod def from_session(cls, sess, input_tensors, output_tensors): @@ -215,7 +225,7 @@ class TocoConverter(object): # Check if graph is frozen. 
if not _is_frozen_graph(sess): - raise ValueError("Please freeze the graph using freeze_graph.py") + raise ValueError("Please freeze the graph using freeze_graph.py.") # Create TocoConverter class. return cls(sess.graph_def, input_tensors, output_tensors) @@ -316,7 +326,9 @@ class TocoConverter(object): reorder_across_fake_quant=self.reorder_across_fake_quant, change_concat_input_ranges=self.change_concat_input_ranges, allow_custom_ops=self.allow_custom_ops, - quantize_weights=self.quantize_weights) + quantize_weights=self.quantize_weights, + dump_graphviz_dir=self.dump_graphviz_dir, + dump_graphviz_video=self.dump_graphviz_video) return result def get_input_arrays(self): diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index bbb00021f95b311f28124ef1bb1eb463b4985d80..8c9d2c1651dd2d0b3cd27cf638c04429e3131efb 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -220,6 +220,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) + # TODO(nupurgarg): Verify value of contents in GraphViz. def testGraphviz(self): in_tensor = array_ops.placeholder( shape=[1, 16, 16, 3], dtype=dtypes.float32) @@ -232,6 +233,39 @@ class FromSessionTest(test_util.TensorFlowTestCase): graphviz_output = converter.convert() self.assertTrue(graphviz_output) + # TODO(nupurgarg): Verify value of contents in GraphViz. + def testDumpGraphviz(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. 
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure interpreter is able to allocate and check graphviz data. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + num_items_graphviz = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + converter.dump_graphviz_video = True + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure graphviz folder has more data after using video flag. + num_items_graphviz_video = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz_video > num_items_graphviz) + def testInferenceInputType(self): in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8) out_tensor = in_tensor + in_tensor @@ -401,7 +435,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as error: lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], ['add']) - self.assertEqual('Please freeze the graph using freeze_graph.py', + self.assertEqual('Please freeze the graph using freeze_graph.py.', str(error.exception)) def testPbtxt(self): diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index 2b7ad29a27c121829a3fbcb41bfdec161ee6a6c9..f497533bed054d260aefc7b3fe67ae655c7cbcda 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -114,9 +114,10 @@ def _convert_model(flags): "--input_arrays must be present when specifying " "--std_dev_values and --mean_values with multiple input "
"tensors in order to map between names and " - "values".format(",".join(input_arrays))) + "values.".format(",".join(input_arrays))) converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) - if flags.default_ranges_min and flags.default_ranges_max: + if (flags.default_ranges_min is not None) and (flags.default_ranges_max is + not None): converter.default_ranges_stats = (flags.default_ranges_min, flags.default_ranges_max) @@ -130,6 +131,10 @@ def _convert_model(flags): converter.allow_custom_ops = flags.allow_custom_ops if flags.quantize_weights: converter.quantize_weights = flags.quantize_weights + if flags.dump_graphviz_dir: + converter.dump_graphviz_dir = flags.dump_graphviz_dir + if flags.dump_graphviz_video: + converter.dump_graphviz_video = flags.dump_graphviz_video # Convert model. output_data = converter.convert() @@ -161,8 +166,12 @@ def _check_flags(flags, unparsed): output = "" for flag in unparsed: output += _get_message_unparsed(flag, "--input_file", "--graph_def_file") + output += _get_message_unparsed(flag, "--savedmodel_directory", + "--saved_model_dir") output += _get_message_unparsed(flag, "--std_value", "--std_dev_values") output += _get_message_unparsed(flag, "--batch_size", "--input_shapes") + output += _get_message_unparsed(flag, "--dump_graphviz", + "--dump_graphviz_dir") if output: raise ValueError(output) @@ -187,7 +196,7 @@ def _check_flags(flags, unparsed): raise ValueError("--std_dev_values, --mean_values must have the same " "number of items") - if bool(flags.default_ranges_min) != bool(flags.default_ranges_max): + if (flags.default_ranges_min is None) != (flags.default_ranges_max is None): raise ValueError("--default_ranges_min and --default_ranges_max must be " "used together") @@ -219,17 +228,17 @@ def run_main(_): # Model format flags.
parser.add_argument( "--output_format", - type=str, + type=str.upper, choices=["TFLITE", "GRAPHVIZ_DOT"], help="Output file format.") parser.add_argument( "--inference_type", - type=str, + type=str.upper, choices=["FLOAT", "QUANTIZED_UINT8"], help="Target data type of arrays in the output file.") parser.add_argument( "--inference_input_type", - type=str, + type=str.upper, choices=["FLOAT", "QUANTIZED_UINT8"], help=("Target data type of input arrays. Allows for a different type for " "input arrays in the case of quantization.")) @@ -322,6 +331,20 @@ def run_main(_): "provide these to the TensorFlow Lite runtime with a custom " "resolver. (default False)")) + # Logging flags. + parser.add_argument( + "--dump_graphviz_dir", + type=str, + help=("Full filepath of folder to dump the graphs at various stages of " + "processing GraphViz .dot files. Preferred over --output_format=" + "GRAPHVIZ_DOT in order to keep the requirements of the output " + "file.")) + parser.add_argument( + "--dump_graphviz_video", + action="store_true", + help=("Boolean indicating whether to dump the graph after every graph " + "transformation")) + tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:]) try: _check_flags(tflite_flags, unparsed) diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc index 64ab0a9fe2f01a732af91ed4052e44cf8c38f89b..9dc8daa227dd68ccde2efa4013ac4465a72e6bb0 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc @@ -39,7 +39,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by -// `schema_builtin_ops_header_generator.py`. +// `schema/builtin_ops_header/generator.cc`. 
#ifdef __cplusplus extern "C" { diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index d12a96df1c70ddaa6ae11f1ee809662314db89b0..c7b955a1659cf65ed0e0233b8b75db60887de34c 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -34,6 +34,7 @@ enum TensorType : byte { INT64 = 4, STRING = 5, BOOL = 6, + INT16 = 7, } // Parameters for converting a quantized tensor back to float. Given a @@ -63,6 +64,8 @@ table Tensor { buffer:uint; name:string; // For debugging and importing back into tensorflow. quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; } // A list of builtin operators. Builtin operators are slightly faster than custom @@ -150,6 +153,7 @@ enum BuiltinOperator : byte { EXPAND_DIMS = 70, EQUAL = 71, NOT_EQUAL = 72, + LOG = 73, } // Options for the builtin operators. @@ -519,6 +523,16 @@ table Operator { builtin_options:BuiltinOptions; custom_options:[ubyte]; custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator.(e.g. used by RNN and LSTM). + // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. 
+ mutating_variable_inputs:[bool]; } // The root type, defining a subgraph, which typically represents an entire diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 8ddd2f14388cb78e23abf98b31485c254aad3e5c..81d4574da7f5025c4dd246b5fc8fe74b7d8b15ae 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -216,11 +216,12 @@ enum TensorType { TensorType_INT64 = 4, TensorType_STRING = 5, TensorType_BOOL = 6, + TensorType_INT16 = 7, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_BOOL + TensorType_MAX = TensorType_INT16 }; -inline TensorType (&EnumValuesTensorType())[7] { +inline TensorType (&EnumValuesTensorType())[8] { static TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, @@ -228,7 +229,8 @@ inline TensorType (&EnumValuesTensorType())[7] { TensorType_UINT8, TensorType_INT64, TensorType_STRING, - TensorType_BOOL + TensorType_BOOL, + TensorType_INT16 }; return values; } @@ -242,6 +244,7 @@ inline const char **EnumNamesTensorType() { "INT64", "STRING", "BOOL", + "INT16", nullptr }; return names; @@ -325,11 +328,12 @@ enum BuiltinOperator { BuiltinOperator_EXPAND_DIMS = 70, BuiltinOperator_EQUAL = 71, BuiltinOperator_NOT_EQUAL = 72, + BuiltinOperator_LOG = 73, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_NOT_EQUAL + BuiltinOperator_MAX = BuiltinOperator_LOG }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[72] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[73] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -402,7 +406,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[72] { BuiltinOperator_TILE, BuiltinOperator_EXPAND_DIMS, BuiltinOperator_EQUAL, - BuiltinOperator_NOT_EQUAL + BuiltinOperator_NOT_EQUAL, + BuiltinOperator_LOG }; return values; } @@ -482,6 +487,7 @@ inline const char 
**EnumNamesBuiltinOperator() { "EXPAND_DIMS", "EQUAL", "NOT_EQUAL", + "LOG", nullptr }; return names; @@ -1668,9 +1674,11 @@ struct TensorT : public flatbuffers::NativeTable { uint32_t buffer; std::string name; std::unique_ptr quantization; + bool is_variable; TensorT() : type(TensorType_FLOAT32), - buffer(0) { + buffer(0), + is_variable(false) { } }; @@ -1681,7 +1689,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_TYPE = 6, VT_BUFFER = 8, VT_NAME = 10, - VT_QUANTIZATION = 12 + VT_QUANTIZATION = 12, + VT_IS_VARIABLE = 14 }; const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); @@ -1698,6 +1707,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const QuantizationParameters *quantization() const { return GetPointer(VT_QUANTIZATION); } + bool is_variable() const { + return GetField(VT_IS_VARIABLE, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && @@ -1708,6 +1720,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.Verify(name()) && VerifyOffset(verifier, VT_QUANTIZATION) && verifier.VerifyTable(quantization()) && + VerifyField(verifier, VT_IS_VARIABLE) && verifier.EndTable(); } TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -1733,6 +1746,9 @@ struct TensorBuilder { void add_quantization(flatbuffers::Offset quantization) { fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization); } + void add_is_variable(bool is_variable) { + fbb_.AddElement(Tensor::VT_IS_VARIABLE, static_cast(is_variable), 0); + } explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1751,12 +1767,14 @@ inline flatbuffers::Offset CreateTensor( TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, flatbuffers::Offset name = 0, - flatbuffers::Offset quantization = 0) { + flatbuffers::Offset quantization = 0, + bool 
is_variable = false) { TensorBuilder builder_(_fbb); builder_.add_quantization(quantization); builder_.add_name(name); builder_.add_buffer(buffer); builder_.add_shape(shape); + builder_.add_is_variable(is_variable); builder_.add_type(type); return builder_.Finish(); } @@ -1767,14 +1785,16 @@ inline flatbuffers::Offset CreateTensorDirect( TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, const char *name = nullptr, - flatbuffers::Offset quantization = 0) { + flatbuffers::Offset quantization = 0, + bool is_variable = false) { return tflite::CreateTensor( _fbb, shape ? _fbb.CreateVector(*shape) : 0, type, buffer, name ? _fbb.CreateString(name) : 0, - quantization); + quantization, + is_variable); } flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -5001,6 +5021,7 @@ struct OperatorT : public flatbuffers::NativeTable { BuiltinOptionsUnion builtin_options; std::vector custom_options; CustomOptionsFormat custom_options_format; + std::vector mutating_variable_inputs; OperatorT() : opcode_index(0), custom_options_format(CustomOptionsFormat_FLEXBUFFERS) { @@ -5016,7 +5037,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_BUILTIN_OPTIONS_TYPE = 10, VT_BUILTIN_OPTIONS = 12, VT_CUSTOM_OPTIONS = 14, - VT_CUSTOM_OPTIONS_FORMAT = 16 + VT_CUSTOM_OPTIONS_FORMAT = 16, + VT_MUTATING_VARIABLE_INPUTS = 18 }; uint32_t opcode_index() const { return GetField(VT_OPCODE_INDEX, 0); @@ -5202,6 +5224,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { CustomOptionsFormat custom_options_format() const { return static_cast(GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); } + const flatbuffers::Vector *mutating_variable_inputs() const { + return GetPointer *>(VT_MUTATING_VARIABLE_INPUTS); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_OPCODE_INDEX) && @@ -5215,6 +5240,8 
@@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_CUSTOM_OPTIONS) && verifier.Verify(custom_options()) && VerifyField(verifier, VT_CUSTOM_OPTIONS_FORMAT) && + VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) && + verifier.Verify(mutating_variable_inputs()) && verifier.EndTable(); } OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -5462,6 +5489,9 @@ struct OperatorBuilder { void add_custom_options_format(CustomOptionsFormat custom_options_format) { fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, static_cast(custom_options_format), 0); } + void add_mutating_variable_inputs(flatbuffers::Offset> mutating_variable_inputs) { + fbb_.AddOffset(Operator::VT_MUTATING_VARIABLE_INPUTS, mutating_variable_inputs); + } explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -5482,8 +5512,10 @@ inline flatbuffers::Offset CreateOperator( BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, flatbuffers::Offset> custom_options = 0, - CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS, + flatbuffers::Offset> mutating_variable_inputs = 0) { OperatorBuilder builder_(_fbb); + builder_.add_mutating_variable_inputs(mutating_variable_inputs); builder_.add_custom_options(custom_options); builder_.add_builtin_options(builtin_options); builder_.add_outputs(outputs); @@ -5502,7 +5534,8 @@ inline flatbuffers::Offset CreateOperatorDirect( BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, const std::vector *custom_options = nullptr, - CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS, + const std::vector *mutating_variable_inputs = nullptr) { 
return tflite::CreateOperator( _fbb, opcode_index, @@ -5511,7 +5544,8 @@ inline flatbuffers::Offset CreateOperatorDirect( builtin_options_type, builtin_options, custom_options ? _fbb.CreateVector(*custom_options) : 0, - custom_options_format); + custom_options_format, + mutating_variable_inputs ? _fbb.CreateVector(*mutating_variable_inputs) : 0); } flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -5882,6 +5916,7 @@ inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t { auto _e = buffer(); _o->buffer = _e; }; { auto _e = name(); if (_e) _o->name = _e->str(); }; { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr(_e->UnPack(_resolver)); }; + { auto _e = is_variable(); _o->is_variable = _e; }; } inline flatbuffers::Offset Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -5897,13 +5932,15 @@ inline flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder & auto _buffer = _o->buffer; auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); auto _quantization = _o->quantization ? 
CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0; + auto _is_variable = _o->is_variable; return tflite::CreateTensor( _fbb, _shape, _type, _buffer, _name, - _quantization); + _quantization, + _is_variable); } inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -7426,6 +7463,7 @@ inline void Operator::UnPackTo(OperatorT *_o, const flatbuffers::resolver_functi { auto _e = builtin_options(); if (_e) _o->builtin_options.value = BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); }; { auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom_options[_i] = _e->Get(_i); } } }; { auto _e = custom_options_format(); _o->custom_options_format = _e; }; + { auto _e = mutating_variable_inputs(); if (_e) { _o->mutating_variable_inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->mutating_variable_inputs[_i] = _e->Get(_i) != 0; } } }; } inline flatbuffers::Offset Operator::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -7443,6 +7481,7 @@ inline flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuild auto _builtin_options = _o->builtin_options.Pack(_fbb); auto _custom_options = _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0; auto _custom_options_format = _o->custom_options_format; + auto _mutating_variable_inputs = _o->mutating_variable_inputs.size() ? 
_fbb.CreateVector(_o->mutating_variable_inputs) : 0; return tflite::CreateOperator( _fbb, _opcode_index, @@ -7451,7 +7490,8 @@ inline flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuild _builtin_options_type, _builtin_options, _custom_options, - _custom_options_format); + _custom_options_format, + _mutating_variable_inputs); } inline SubGraphT *SubGraph::UnPack(const flatbuffers::resolver_function_t *_resolver) const { diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc index 2f2004f56bcad5b56f9dd6d4bc824ec14d79e795..4eaf6f1bfe76efc1e6737d03d58be9bc87bb849d 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.cc +++ b/tensorflow/contrib/lite/simple_memory_arena.cc @@ -36,6 +36,12 @@ TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context, ArenaAlloc* new_alloc) { TF_LITE_ENSURE(context, alignment < arena_alignment_); + if (size == 0) { + new_alloc->offset = 0; + new_alloc->size = 0; + return kTfLiteOk; + } + size_t current_top = 0; if (!allocs_.empty()) { @@ -75,6 +81,10 @@ TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context, TfLiteStatus SimpleMemoryArena::Deallocate(TfLiteContext* context, const ArenaAlloc& alloc) { + if (alloc.size == 0) { + return kTfLiteOk; + } + int erased_allocs_count = 0; auto it = allocs_.begin(); while (it != allocs_.end()) { @@ -122,7 +132,11 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context, char** output_ptr) { TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); - *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + if (alloc.size == 0) { + *output_ptr = nullptr; + } else { + *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + } return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h index 5faf78b59e3755d22e4e866d433e622baa6c66c1..f738315cf2f91403f9dcb6fa9e66b40bd70495aa 100644 --- 
a/tensorflow/contrib/lite/simple_memory_arena.h +++ b/tensorflow/contrib/lite/simple_memory_arena.h @@ -39,7 +39,8 @@ struct ArenaAlloc { // This small class is responsible for allocating, deallocating and reusing // dynamic memory from a common underlying buffer. The arena can be used in // scenarios when the pattern of memory allocations and deallocations is -// repetitive, e.g. running NN inference in multiple iterations. +// repetitive, e.g. running NN inference in multiple iterations. Note that +// zero-sized allocations are explicitly allowed, and will resolve to null. class SimpleMemoryArena { public: explicit SimpleMemoryArena(size_t arena_alignment) diff --git a/tensorflow/contrib/lite/simple_memory_arena_test.cc b/tensorflow/contrib/lite/simple_memory_arena_test.cc index 4444f642eb75c563c57762d095e454ac63d836c6..60d4d5e768aeda958574422e1c36a7cc2f6a1429 100644 --- a/tensorflow/contrib/lite/simple_memory_arena_test.cc +++ b/tensorflow/contrib/lite/simple_memory_arena_test.cc @@ -43,6 +43,47 @@ TEST(SimpleMemoryArenaTest, BasicArenaOperations) { EXPECT_EQ(allocs[5].offset, 1024); } +TEST(SimpleMemoryArenaTest, BasicZeroAlloc) { + TfLiteContext context; + SimpleMemoryArena arena(64); + ArenaAlloc alloc; + + // Zero-sized allocs should have a 0 offset and size. + ASSERT_EQ(arena.Allocate(&context, 32, 0, &alloc), kTfLiteOk); + EXPECT_EQ(alloc.offset, 0); + EXPECT_EQ(alloc.size, 0); + + // Deallocation of zero-sized allocs should always succeed (even redundantly). + ASSERT_EQ(arena.Deallocate(&context, alloc), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, alloc), kTfLiteOk); + + // The zero-sized alloc should resolve to null. 
+ char* resolved_ptr = nullptr; + ASSERT_EQ(arena.Commit(&context), kTfLiteOk); + ASSERT_EQ(arena.ResolveAlloc(&context, alloc, &resolved_ptr), kTfLiteOk); + EXPECT_EQ(resolved_ptr, nullptr); +} + +TEST(SimpleMemoryArenaTest, InterleavedZeroAlloc) { + TfLiteContext context; + SimpleMemoryArena arena(64); + ArenaAlloc allocs[4]; + + // Interleave some zero and non-zero-sized allocations and deallocations. + ASSERT_EQ(arena.Allocate(&context, 32, 2047, &allocs[0]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 0, &allocs[1]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 1023, &allocs[2]), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, allocs[1]), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, allocs[2]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 2047, &allocs[3]), kTfLiteOk); + + // Deallocation of a zero-sized alloc should not impact the allocator offsets. + EXPECT_EQ(allocs[0].offset, 0); + EXPECT_EQ(allocs[1].offset, 0); + EXPECT_EQ(allocs[2].offset, 2048); + EXPECT_EQ(allocs[3].offset, 2048); +} + TEST(SimpleMemoryArenaTest, TestAfterClear) { TfLiteContext context; SimpleMemoryArena arena(64); diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc index a89776b29f895fe82ee71efe00c0949c58c109df..a316a40b62d89189da43768d448acdf5bbeca129 100644 --- a/tensorflow/contrib/lite/string_util.cc +++ b/tensorflow/contrib/lite/string_util.cc @@ -105,7 +105,7 @@ void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) { dims->data[0] = offset_.size() - 1; // Store number of strings. 
TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params, tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation, - tensor); + tensor->is_variable, tensor); } int GetStringCount(const char* raw_buffer) { diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 80e4c5a4dde4702229887593afc5ffeef339176d..b823c97f38e7660652aa0ce3538b11de59dc9aea 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -20,11 +20,15 @@ load( size = "large", srcs = ["generated_examples_zip_test.cc"], args = [ - "--zip_file_path=$(location :zip_%s)" % test_name, - # TODO(angerson) We may be able to add an external unzip binary instead - # of relying on an existing one for OSS builds. - "--unzip_binary_path=/usr/bin/unzip", - ], + ] + select({ + "//tensorflow:android": [], + "//conditions:default": [ + "--zip_file_path=$(location :zip_%s)" % test_name, + # TODO(angerson) We may be able to add an external unzip binary instead + # of relying on an existing one for OSS builds. 
+ "--unzip_binary_path=/usr/bin/unzip", + ], + }), data = [ ":zip_%s" % test_name, ], diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 723b6ae057e2db5479784811d404b92d30cb0d14..f5e25784fa17209af7cfb06d32aeea2b9b947196 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -2420,30 +2420,44 @@ def make_neg_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_sin_tests(zip_path): - """Make a set of tests to do sin.""" +def _make_elementwise_tests(op): + """Make a set of tests to do element-wise operations.""" - test_parameters = [{ - "input_dtype": [tf.float32], - "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], - }] + def f(zip_path): + """Actual function that generates examples.""" + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + }] - def build_graph(parameters): - """Build the sin op testing graph.""" - input_value = tf.placeholder( - dtype=parameters["input_dtype"], - name="input1", - shape=parameters["input_shape"]) - out = tf.sin(input_value) - return [input_value], [out] + def build_graph(parameters): + """Build the sin op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape"]) + out = op(input_value) + return [input_value], [out] - def build_inputs(parameters, sess, inputs, outputs): - input_value = create_tensor_data(parameters["input_dtype"], - parameters["input_shape"]) - return [input_value], sess.run( - outputs, feed_dict={inputs[0]: input_value}) + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict={inputs[0]: input_value}) - make_zip_of_tests(zip_path, 
test_parameters, build_graph, build_inputs) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return f + + +def make_sin_tests(zip_path): + """Make a set of tests to do sin.""" + return _make_elementwise_tests(tf.sin)(zip_path) + + +def make_log_tests(zip_path): + """Make a set of tests to do log.""" + return _make_elementwise_tests(tf.log)(zip_path) def make_where_tests(zip_path): diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index e85020448a572650c6a70d8b4dcb4e73faf0f8c8..8a59d756f8dbbcefc930b5285c1ced8ce6b08845 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -36,7 +36,12 @@ bool FLAGS_ignore_known_bugs = true; // TODO(b/71769302) zip_files_dir should have a more accurate default, if // possible string* FLAGS_zip_file_path = new string("./"); +#ifndef __ANDROID__ string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip"); +#else +string* FLAGS_unzip_binary_path = new string("/system/bin/unzip"); +#endif +bool FLAGS_use_nnapi = false; } // namespace // TensorFlow system environment for file system called. 
@@ -212,7 +217,7 @@ TEST_P(OpsTest, RunZipTests) { std::ifstream tflite_stream(tflite_test_case); ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case; - tflite::testing::TfLiteDriver test_driver(/*use_nnapi=*/true); + tflite::testing::TfLiteDriver test_driver(FLAGS_use_nnapi); test_driver.SetModelBaseDir(tflite_dir); string bug_number; @@ -273,7 +278,10 @@ int main(int argc, char** argv) { "Required: Location of the test zip file."), tensorflow::Flag("unzip_binary_path", tflite::testing::FLAGS_unzip_binary_path, - "Required: Location of a suitable unzip binary.")}; + "Required: Location of a suitable unzip binary."), + tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi, + "Whether to enable the NNAPI delegate")}; + bool success = tensorflow::Flags::Parse(&argc, argv, flags); if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) { fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); @@ -281,6 +289,8 @@ int main(int argc, char** argv) { } ::tflite::LogToStderr(); + // TODO(mikie): googletest arguments do not work - maybe the tensorflow flags + // parser removes them? ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index fc28faf52405b300dc6e4f0aab33122bb5e98f12..54edfdfb1df3f45b4823a36503c01551348ead6c 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -163,6 +163,7 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) { Invalidate("Failed build interpreter"); return; } + interpreter_->UseNNAPI(use_nnapi_); must_allocate_tensors_ = true; } @@ -284,7 +285,9 @@ bool TfLiteDriver::CheckResults() { } void TfLiteDriver::ResetLSTMStateTensors() { - // This is a workaround for initializing state tensors for LSTM. 
+ interpreter_->ResetVariableTensorsToZero(); + + // Below is a workaround for initializing state tensors for LSTM. // TODO(ycling): Refactoring and find a better way to initialize state // tensors. Maybe write the reset instructions into the test data. for (auto node_index : interpreter_->execution_plan()) { @@ -302,13 +305,6 @@ void TfLiteDriver::ResetLSTMStateTensors() { int node_index = node.outputs->data[i]; ResetTensor(node_index); } - } else if (params->kernel_type == kTfLiteLSTMBasicKernel && - node.inputs->size == 5) { - // The 2th and 5th inputs are state tensors. - for (int i : {1, 4}) { - int node_index = node.inputs->data[i]; - ResetTensor(node_index); - } } } } diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 7ea4f32ef694f3b0dc9c030b9440268ac79848aa..dd05c484fabf4d87dc12b39940a71677af4023e2 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -213,6 +213,7 @@ cc_library( "graph_transformations/convert_squeeze_to_reshape.cc", "graph_transformations/convert_trivial_addn_to_add.cc", "graph_transformations/convert_trivial_stack_to_reshape.cc", + "graph_transformations/convert_trivial_tile_to_concat.cc", "graph_transformations/convert_trivial_transpose_to_reshape.cc", "graph_transformations/create_im2col_arrays.cc", "graph_transformations/dequantize.cc", @@ -224,6 +225,7 @@ cc_library( "graph_transformations/fuse_activation_functions.cc", "graph_transformations/fuse_binary_into_following_affine.cc", "graph_transformations/fuse_binary_into_preceding_affine.cc", + "graph_transformations/fuse_broadcast_into_following_binary.cc", "graph_transformations/graph_transformations.cc", "graph_transformations/hardcode_min_max.cc", "graph_transformations/identify_dilated_conv.cc", @@ -293,7 +295,6 @@ cc_library( "graph_transformations/resolve_tensorflow_matmul.cc", "graph_transformations/resolve_tensorflow_merge.cc", "graph_transformations/resolve_tensorflow_switch.cc", - 
"graph_transformations/resolve_tensorflow_tile.cc", "graph_transformations/resolve_transpose_attributes.cc", "graph_transformations/unfuse_activation_functions.cc", "graph_transformations/unpartition_embedding_lookup.cc", @@ -374,6 +375,7 @@ tf_cc_test( ":toco_tooling", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_googletest//:gtest_main", ], @@ -411,6 +413,7 @@ tf_cc_test( deps = [ ":model", ":tooling_util", + "//tensorflow/core:lib", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 8913b5c3ea962725ef2bed73e670e8f0b988a591..878bda36ef3900d6d8c509aca40cee834cefe514 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -146,6 +146,7 @@ NodeProperties GetPropertiesForArray(const Model& model, NodeProperties node_properties; node_properties.color = GetColorForArray(model, array_name); node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}}); + node_properties.log2_buffer_size = 0.0f; // Append array shape to the label. 
auto& array = model.GetArray(array_name); @@ -165,9 +166,12 @@ NodeProperties GetPropertiesForArray(const Model& model, } node_properties.label += "]"; - int buffer_size = RequiredBufferSizeForShape(array.shape()); - node_properties.log2_buffer_size = - std::log2(static_cast(buffer_size)); + int buffer_size = 0; + if (IsValid(array.shape())) { + buffer_size = RequiredBufferSizeForShape(array.shape()); + node_properties.log2_buffer_size = + std::log2(static_cast(buffer_size)); + } if (array.buffer) { const auto& array = model.GetArray(array_name); @@ -200,8 +204,6 @@ NodeProperties GetPropertiesForArray(const Model& model, AppendF(&node_properties.label, "}"); } } - } else { - node_properties.log2_buffer_size = 0.0f; } if (array.minmax) { diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 76ce1c58025ce84219ed6a8a0b6f2ea6e18e037c..6e5e0d013750c8669f73003fb9ee861bb4aecb2f 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -494,7 +494,7 @@ void ConvertTransposeConvOperator(const Model& model, const auto& weights_array = model.GetArray(weights_array_name); CHECK(weights_array.buffer->type == ArrayDataType::kFloat); ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI, - AxesOrder::kHWIO, tensorflow_graph); + AxesOrder::kHWOI, tensorflow_graph); auto& strides = (*conv2d_op->mutable_attr())["strides"]; strides.mutable_list()->add_i(1); strides.mutable_list()->add_i(src_op.stride_height); @@ -1687,6 +1687,22 @@ void ConvertSelectOperator(const Model& model, const SelectOperator& src_op, (*sub_op->mutable_attr())["T"].set_type(data_type); } +void ConvertTileOperator(const Model& model, + const TensorFlowTileOperator& src_op, + GraphDef* tensorflow_graph) { + auto* tile_op = tensorflow_graph->add_node(); + tile_op->set_op("Tile"); + tile_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + 
*tile_op->add_input() = src_op.inputs[0]; + *tile_op->add_input() = src_op.inputs[1]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*tile_op->mutable_attr())["T"].set_type(data_type); + const auto multiples_data_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type); +} + void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, GraphDef* tensorflow_graph) { auto* topk_op = tensorflow_graph->add_node(); @@ -1953,6 +1969,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kSelect) { ConvertSelectOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowTile) { + ConvertTileOperator(model, + static_cast(src_op), + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ab399206ba5376e2ff7c5c7028a1ea3e9b92a03 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { + auto tile_it = model->operators.begin() + op_index; + if (tile_it->get()->type != OperatorType::kTensorFlowTile) { + return false; + } + auto* tile_op = static_cast(tile_it->get()); + + const auto& input_array = model->GetArray(tile_op->inputs[0]); + const auto& multiples_array = model->GetArray(tile_op->inputs[1]); + const auto& output_array = model->GetArray(tile_op->outputs[0]); + if (!input_array.has_shape() || !multiples_array.has_shape() || + !output_array.has_shape()) { + // Yield until PropagateFixedSizes has been run on this op. + return false; + } + // Note: We can assume we have error checked inputs in PropagateFixedSizes. + + if (!multiples_array.buffer) { + // Yield until the multiples is constant. + return false; + } + std::vector const& multiples = + multiples_array.GetBuffer().data; + + // We can simplify the tile if only a single dimension is being multiplied. + // It then just becomes a concat along that dimension. + int non_one_dims = 0; + int concat_axis = 0; + for (int i = 0; i < multiples.size(); ++i) { + if (multiples[i] != 1) { + ++non_one_dims; + concat_axis = i; + } + } + if (non_one_dims != 1) { + // The tile is non-trivial. Good luck. + AddMessageF("Tile %s is non-trivial (has more than one multiply dimension)", + LogName(*tile_op)); + return false; + } + + // The tile is like a concat. + AddMessageF("Simplifying %s to a Concat along a single axis %d", + LogName(*tile_op), concat_axis); + + auto* concat_op = new ConcatenationOperator; + + // Copy input and output. 
+ // Note that we multiply out the input by the number of times requested. + for (int i = 0; i < multiples[concat_axis]; ++i) { + concat_op->inputs.push_back(tile_op->inputs[0]); + } + concat_op->axis = concat_axis; + concat_op->outputs = tile_op->outputs; + + // Delete multiples array if unused. + if (IsDiscardableArray(*model, tile_op->inputs[1]) && + CountOpsWithInput(*model, tile_op->inputs[1]) == 1) { + model->EraseArray(tile_op->inputs[1]); + } + + // Replace the operator in the graph. + const auto concat_it = model->operators.emplace(tile_it, concat_op); + tile_it = concat_it + 1; + CHECK_EQ(tile_it->get(), tile_op); + model->operators.erase(tile_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc index 076415ece8c1039caa32e947fe54ab3e101bec9e..1e68cd678bce6c27f1852a5ae0c13362d8938cdd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -25,17 +25,12 @@ limitations under the License. 
namespace toco { -bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { - auto conv_it = model->operators.begin() + op_index; - if (conv_it->get()->type != OperatorType::kConv) { - return false; - } - auto* conv_op = static_cast(conv_it->get()); - if (conv_op->outputs.size() == 2) { +bool ProcessConvOperator(Model* model, ConvOperator* op) { + if (op->outputs.size() == 2) { // We already have an im2col array return false; } - const auto& weights_array = model->GetArray(conv_op->inputs[1]); + const auto& weights_array = model->GetArray(op->inputs[1]); if (!weights_array.has_shape()) { // We need to yield until weights dims have been resolved, because // from the weights dims we determine whether an im2col array is @@ -45,25 +40,52 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { const auto& weights_shape = weights_array.shape(); const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); - if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 && - conv_op->stride_height == 1) { - // 1x1 unstrided conv does not need an im2col array. + if (kwidth == 1 && kheight == 1 && op->stride_width == 1 && + op->stride_height == 1 && op->dilation_width_factor == 1 && + op->dilation_height_factor == 1) { + // 1x1 unstrided undilated conv does not need an im2col array. return false; } // Create the im2col array. 
- CHECK_EQ(conv_op->outputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); const string& im2col_array_name = - AvailableArrayName(*model, conv_op->inputs[0] + "_im2col"); + AvailableArrayName(*model, op->inputs[0] + "_im2col"); model->GetOrCreateArray(im2col_array_name); - conv_op->outputs.push_back(im2col_array_name); - AddMessageF( - "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, " - "stride_height=%d", - LogName(*conv_op), kwidth, kheight, conv_op->stride_width, - conv_op->stride_height); + op->outputs.push_back(im2col_array_name); return true; } +bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { + if (op->outputs.size() == 2) { + // We already have an im2col array + return false; + } + + // Always create an im2col array for transpose_conv. + CHECK_EQ(op->outputs.size(), 1); + const string& im2col_array_name = AvailableArrayName( + *model, op->inputs[TransposeConvOperator::DATA_INPUT] + "_im2col"); + model->GetOrCreateArray(im2col_array_name); + op->outputs.push_back(im2col_array_name); + + return true; +} + +bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + + switch (op->type) { + case OperatorType::kConv: + return ProcessConvOperator(model, static_cast(op)); + case OperatorType::kTransposeConv: + return ProcessTransposeConvOperator( + model, static_cast(op)); + default: + return false; + } +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc new file mode 100644 index 0000000000000000000000000000000000000000..874d8def571fbce4219de15285c8df6fd2487a9a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +// Returns true if the given op is strictly a broadcasting operation. +// This is commonly seen as a Concat of the same input multiple times, and is +// often generated from Tile ops that were converted via the +// convert_trivial_tile_to_concat transformation. +bool IsBroadcastingOp(const Model& model, Operator* op) { + // Concatenation of identical inputs is usually a broadcast. + if (op->type == OperatorType::kConcatenation) { + // Verify that all inputs are the same. + for (int i = 1; i < op->inputs.size(); ++i) { + if (op->inputs[i] != op->inputs[0]) { + return false; + } + } + return true; + } + + // There are other things we could look for (Stack/etc) when needed. + return false; +} + +} // namespace + +// Finds an operation that looks like a broadcast (concat of the same sources +// along the last dimension) and drops it by relying on the ability of certain +// binary ops to perform an implicit broadcast. 
+bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + + // Test for binary ops of types that we know how to resolve + if (binary_op->inputs.size() != 2) { + return false; + } + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + // NOTE: either of these ops may be nullptr if the input array is constant. + Operator* const op[2] = { + GetOpWithOutput(*model, binary_op->inputs[0]), + GetOpWithOutput(*model, binary_op->inputs[1]), + }; + + // Check whether either input is a broadcast-like concat. + bool is_op_0_broadcast = op[0] && IsBroadcastingOp(*model, op[0]); + bool is_op_1_broadcast = op[1] && IsBroadcastingOp(*model, op[1]); + if (!is_op_0_broadcast && !is_op_1_broadcast) { + // Neither input is a broadcast-looking thing. + AddMessageF("Neither input looks broadcasty"); + return false; + } else if (is_op_0_broadcast && is_op_1_broadcast) { + AddMessageF( + "Unable to fuse broadcast into %s as both inputs (%s, %s) are " + "broadcasts", + LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)", + op[1] ? LogName(*op[1]) : "(?)"); + return false; + } + int broadcast_index = is_op_0_broadcast ? 0 : 1; + + // Just pull out the input of the broadcast op and pass it directly to the + // binary op. + AddMessageF("Fusing broadcast op %s into the following binary %s", + LogName(*op[broadcast_index]), LogName(*binary_op)); + binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0]; + + // We leave the broadcast op in; it'll get cleaned up if it's not used later. 
+ return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 1bc7557d46cfa5e1b27468d2da271e75fd491d36..62a09acdfbb553161e480051aa506486b9adec47 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -117,12 +117,14 @@ DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialStackToReshape) +DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTileToConcat) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes) DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine) +DECLARE_GRAPH_TRANSFORMATION(FuseBroadcastIntoFollowingBinary) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) @@ -165,7 +167,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge) DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch) -DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 
d63ee7c9519d169a2f44ec1afe81125217db8976..bda6dce22be0f0ca83eb8339ad17573b0267c18c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -362,6 +362,8 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForAverageOrMaxPool(model, op); break; + case OperatorType::kResizeBilinear: + case OperatorType::kSlice: case OperatorType::kStridedSlice: case OperatorType::kSqueeze: case OperatorType::kTensorFlowReshape: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc index ae3301f467de5714230e731b4bab87ddc1637201..d49857cfc22ecaf5feb06b39a42187f8adb61d50 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -90,12 +90,13 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { } // Conv Op - ConvOperator* conv_op = dynamic_cast( - has_expand_op ? GetOpWithInput(*model, post_stb_op->outputs[0]) - : GetOpWithInput(*model, stb_op->outputs[0])); - if (!conv_op || conv_op->type != OperatorType::kConv) { + const string& input_of_conv_op = + has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0]; + auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op); + if (conv_base_op->type != OperatorType::kConv) { return false; } + auto* conv_op = static_cast(conv_base_op); if (conv_op->inputs.size() != 2) { // The conv op must only have weights, no bias. 
return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 6d51fc8c31e6c86701c3dc1fd07a9a5479114738..77c08868117382f9daf900da79286e9f9e06d9db 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -103,6 +103,7 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) { case OperatorType::kTensorFlowReshape: case OperatorType::kTranspose: case OperatorType::kSelect: + case OperatorType::kTensorFlowTile: // Reshapes and transposes don't change values. return false; default: @@ -124,6 +125,9 @@ bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) { case OperatorType::kTranspose: // Ignore reshape/transpose shapes/dimensions. return input_index != 0; + case OperatorType::kTensorFlowTile: + // Ignore tile multiples. + return input_index != 0; default: return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 9e4262223e416178e40f200e227ae6fa316a2728..e7da9051d835c30f93838b0c5be45dbcc92a70c1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -211,12 +211,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { // might as well calculate the output shape and ensure it matches the // specified one - // Check if we have already run. - auto& output_array = model->GetArray(op->outputs[0]); - if (output_array.has_shape()) { - return; - } - // SPECIFIED OUTPUT SHAPE // The below is the specified, or prescribed output shape, _given_ to the // operator as an input. 
@@ -278,13 +272,23 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { << "TransposeConv input shape must have 4 dimensions. Input \"" << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape " << toco::ShapeToString(weights_shape) << "."; - CHECK_EQ(input_shape.dims(3), weights_shape.dims(0)) + CHECK_EQ(input_shape.dims(3), weights_shape.dims(3)) << "Input shape depth and weight depth do not agree"; // Set the output shape according to the specified output shape. std::vector const& specified_output_shape = specified_output_shape_array.GetBuffer().data; + auto& output_array = model->GetArray(op->outputs[0]); *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape; + + // Set im2col array dimensions if there is one. + if (op->outputs.size() == 2) { + const int input_depth = weights_shape.dims(3); + auto& im2col_array = model->GetArray(op->outputs[1]); + im2col_array.copy_shape( + Shape{specified_output_shape[0], specified_output_shape[1], + specified_output_shape[2], input_depth * kheight * kwidth}); + } } void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { @@ -1505,6 +1509,48 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) { } } +void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // We have already run. + return; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + + auto& multiples_array = model->GetArray(op->inputs[1]); + if (!multiples_array.has_shape()) { + // Yield until multiples shape been resolved. + return; + } + if (!multiples_array.buffer) { + // Yield until the multiples is constant. 
+ return; + } + CHECK(multiples_array.data_type == ArrayDataType::kInt32) + << "Tile multiples input must be int32"; + + std::vector const& multiples = + multiples_array.GetBuffer().data; + CHECK_EQ(multiples.size(), input_shape.dimensions_count()) + << "Tile multiples input " << op->inputs[1] + << " must be same length as input dimensions"; + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->resize(multiples.size()); + for (int i = 0; i < mutable_dims->size(); ++i) { + (*mutable_dims)[i] = input_shape.dims(i) * multiples[i]; + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1623,14 +1669,6 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessSliceOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowTile: - // We don't currently implement the propagation of fixed sizes through - // a TensorFlow Tile. - // - // Fortunately, we don't need to: so far, we have only dealt with Tile - // or Slice ops in subgraphs that are identified as L2Normalization. - // See IdentifyL2Normalization. - break; case OperatorType::kTensorFlowSwitch: // We can't know the sizes of the outputs until we have resolved the // predicate, and once we have resolved the predicate, the whole @@ -1734,6 +1772,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessSparseToDenseOperator(model, static_cast(op)); break; + case OperatorType::kTensorFlowTile: + ProcessTileOperator(model, static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. 
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index ab24c4f9966a37995f23f600263fe96aba6da2d6..eca2c701f8bbf889088794c939af7082db0734dd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -45,12 +45,14 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kTensorFlowMinimum || type == OperatorType::kTensorFlowMaximum || type == OperatorType::kLogistic || type == OperatorType::kSoftmax || - type == OperatorType::kLogSoftmax || + type == OperatorType::kLogSoftmax || type == OperatorType::kSlice || + type == OperatorType::kResizeBilinear || type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub || type == OperatorType::kSqueeze || type == OperatorType::kPad || type == OperatorType::kPadV2 || type == OperatorType::kTensorFlowReshape || type == OperatorType::kTanh || type == OperatorType::kMul || + type == OperatorType::kSpaceToBatchND || type == OperatorType::kSpaceToDepth || type == OperatorType::kStridedSlice || type == OperatorType::kDepthToSpace || diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc deleted file mode 100644 index 1ddf54c778cd1fae7a8fce0ecb97209274e71ac0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include -#include -#include - -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" - -namespace toco { - -namespace { - -void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op, - int operand_index) { - CHECK(tile_op->type == OperatorType::kTensorFlowTile); - CHECK_EQ(binary_op->inputs.size(), 2); - CHECK_EQ(tile_op->inputs.size(), 2); - const string tile_multiplier_array = tile_op->inputs[1]; - const string tile_output_array = tile_op->outputs[0]; - binary_op->inputs[operand_index] = tile_op->inputs[0]; - auto tile_it = model->operators.begin(); - for (; tile_it != model->operators.end(); ++tile_it) { - if (tile_it->get() == tile_op) { - break; - } - } - CHECK(tile_it != model->operators.end()); - CHECK(tile_it->get() == tile_op); - model->operators.erase(tile_it); - if (!CountOpsWithInput(*model, tile_multiplier_array) && - !GetOpWithOutput(*model, tile_multiplier_array)) { - model->EraseArray(tile_multiplier_array); - } - if (!CountOpsWithInput(*model, tile_output_array)) { - model->EraseArray(tile_output_array); - } -} -} // namespace - -bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) { - const auto binary_it = model->operators.begin() + op_index; - auto* binary_op = binary_it->get(); - // Test for binary ops of types that we know how to resolve - if 
(binary_op->inputs.size() != 2) { - return false; - } - if (binary_op->type != OperatorType::kAdd && - binary_op->type != OperatorType::kMul && - binary_op->type != OperatorType::kSub && - binary_op->type != OperatorType::kDiv) { - return false; - } - - Operator* const op[2] = { - GetOpWithOutput(*model, binary_op->inputs[0]), - GetOpWithOutput(*model, binary_op->inputs[1]), - }; - - // In the unlikely case where both operands are Tile, we can't infer the - // output - // size without the Tile nodes, so we have to bail out. - if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] && - op[1]->type == OperatorType::kTensorFlowTile) { - return false; - } - - for (int i = 0; i < 2; i++) { - if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) { - // We can only remove a Tile operator is no other op than the present - // binary op was consuming its tiled output. - if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) { - AddMessageF("Removing %s", LogName(*op[i])); - RemoveTileOperator(model, op[i], binary_op, i); - return true; - } - } - } - return false; -} - -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index a18748ae96f0a26bafdf13b7f33699fdb3195bd0..cd4f034dfea57b6d379b67a90ba4fa3fe3d615d5 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" #include "tensorflow/contrib/lite/toco/tensorflow_util.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" @@ -44,6 +43,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -63,8 +63,6 @@ using tensorflow::TensorShapeProto; namespace toco { -using port::Status; - namespace { bool HasAttr(const NodeDef& node, const string& attr_name) { return node.attr().count(attr_name) > 0; @@ -130,6 +128,42 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node, return attr.list(); } +tensorflow::Status CheckOptionalAttr(const NodeDef& node, + const string& attr_name, + const string& expected_value) { + if (HasAttr(node, attr_name)) { + const string& value = GetStringAttr(node, attr_name); + if (value != expected_value) { + return tensorflow::errors::InvalidArgument( + "Unexpected value for attribute '" + attr_name + "'. Expected '" + + expected_value + "'"); + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Status CheckOptionalAttr( + const NodeDef& node, const string& attr_name, + const tensorflow::DataType& expected_value) { + if (HasAttr(node, attr_name)) { + const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name); + if (value != expected_value) { + return tensorflow::errors::InvalidArgument( + "Unexpected value for attribute '" + attr_name + "'. 
Expected '" + + tensorflow::DataType_Name(expected_value) + "'"); + } + } + return tensorflow::Status::OK(); +} + +template +tensorflow::Status ExpectValue(const T1& v1, const T2& v2, + const string& description) { + if (v1 == v2) return tensorflow::Status::OK(); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Unexpected ", description, ": got ", v1, ", expected ", v2)); +} + ArrayDataType ConvertDataType(tensorflow::DataType dtype) { if (dtype == DT_UINT8) return ArrayDataType::kUint8; @@ -148,9 +182,10 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) { return ArrayDataType::kNone; } -Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< - tensorflow::TensorShapeProto_Dim>& input_dims, - int* input_flat_size, Shape* shape) { +tensorflow::Status ImportShape( + const TFLITE_PROTO_NS::RepeatedPtrField& + input_dims, + int* input_flat_size, Shape* shape) { std::vector input_dims_only_sizes; for (auto& d : input_dims) { if (d.size() == 0) { @@ -160,23 +195,24 @@ Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< // For now, tweaking this to record a 0-D shape instead. shape->mutable_dims()->clear(); if (input_flat_size != nullptr) *input_flat_size = 0; - return Status::OK(); + return tensorflow::Status::OK(); } // TensorFlow's shapes use int64s, while TOCO uses ints. 
if (d.size() > std::numeric_limits::max()) { - return Status(false, "Shape element overflows"); + return tensorflow::errors::InvalidArgument("Shape element overflows"); } input_dims_only_sizes.push_back(d.size()); } *shape->mutable_dims() = input_dims_only_sizes; - if (input_flat_size == nullptr) return Status::OK(); + if (input_flat_size == nullptr) return tensorflow::Status::OK(); return NumElements(input_dims_only_sizes, input_flat_size); } -Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_FLOAT); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -203,18 +239,18 @@ Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_float_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", input_tensor.tensor_content().size() / sizeof(float), ") nor float_val (", input_tensor.float_val_size(), ") have the right dimensions (", input_flat_size, ") for this float tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_QUINT8); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -236,18 +272,18 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", 
input_tensor.tensor_content().size() / sizeof(uint8_t), ") nor int_val (", input_tensor.int_val_size(), ") have the right dimensions (", input_flat_size, ") for this uint8 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT32); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -269,18 +305,17 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status( - false, - absl::StrCat("Neither input_content (", - input_tensor.tensor_content().size() / sizeof(int32), - ") nor int_val (", input_tensor.int_val_size(), - ") have the right dimensions (", input_flat_size, - ") for this int32 tensor")); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Neither input_content (", + input_tensor.tensor_content().size() / sizeof(int32), ") nor int_val (", + input_tensor.int_val_size(), ") have the right dimensions (", + input_flat_size, ") for this int32 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -302,18 +337,18 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", 
input_tensor.tensor_content().size() / sizeof(int64), ") nor int64_val (", input_tensor.int64_val_size(), ") have the right dimensions (", input_flat_size, ") for this int64 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_BOOL); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -343,19 +378,19 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { // So far only encountered that in an array with 1 entry, let's // require that until we encounter a graph where that's not the case. if (output_bool_data.size() != 1) { - return Status( - false, absl::StrCat("Neither input_content (", - input_tensor.tensor_content().size(), - ") nor bool_val (", input_tensor.bool_val_size(), - ") have the right dimensions (", input_flat_size, - ") for this bool tensor")); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Neither input_content (", input_tensor.tensor_content().size(), + ") nor bool_val (", input_tensor.bool_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this bool tensor")); } output_bool_data[0] = false; } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportStringArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -365,9 +400,9 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { if (!status.ok()) return status; if (input_flat_size != input_tensor.string_val_size()) { - return Status(false, - "Input_content string_val doesn't have the right 
dimensions " - "for this string tensor"); + return tensorflow::errors::InvalidArgument( + "Input_content string_val doesn't have the right dimensions " + "for this string tensor"); } auto& output_string_data = @@ -377,7 +412,7 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { for (int i = 0; i < input_flat_size; ++i) { output_string_data[i] = input_tensor.string_val(i); } - return Status::OK(); + return tensorflow::Status::OK(); } // Count the number of inputs of a given node. If @@ -417,14 +452,14 @@ string CreateConstArray(Model* model, string const& name, return array_name; } -Status ConvertConstOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConstOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Const"); const auto& tensor = GetTensorAttr(node, "value"); const auto dtype = GetDataTypeAttr(node, "dtype"); - Status status = Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); auto& array = model->GetOrCreateArray(node.name()); switch (dtype) { @@ -460,24 +495,21 @@ Status ConvertConstOperator(const NodeDef& node, array.GetMutableBuffer(); break; } - if (!status.ok()) { - status.AppendMessage(" (while processing node '" + node.name() + "')"); - } - return status; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + status, " (while processing node '" + node.name() + "')"); + return tensorflow::Status::OK(); } -void ConvertConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Conv2D"); CheckInputsCount(node, tf_import_flags, 2); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. 
- if (HasAttr(node, "data_format")) { - CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC"); - } - CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC")); + TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT)); const auto& input_name = node.input(0); const auto& weights_name = node.input(1); @@ -502,27 +534,26 @@ void ConvertConvOperator(const NodeDef& node, auto* conv = new ConvOperator; conv->inputs = {input_name, reordered_weights_name}; conv->outputs = {node.name()}; + if (!HasAttr(node, "strides")) { + return tensorflow::errors::InvalidArgument("Missing attribute 'strides'"); + } const auto& strides = GetListAttr(node, "strides"); - CHECK_EQ(strides.i_size(), 4); - CHECK_EQ(strides.i(0), 1); - CHECK_EQ(strides.i(3), 1); + TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides")); + TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)")); + TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)")); conv->stride_height = strides.i(1); conv->stride_width = strides.i(2); if (HasAttr(node, "dilations")) { const auto& dilations = GetListAttr(node, "dilations"); - CHECK_EQ(dilations.i_size(), 4); - CHECK_EQ(dilations.i(0), 1) - << "Can only import Conv ops with dilation along the height (1st) or " - "width (2nd) axis. TensorFlow op \"" - << node.name() << "\" had dilations:[ " << dilations.i(0) << ", " - << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3) - << "]."; - CHECK_EQ(dilations.i(3), 1) - << "Can only import Conv ops with dilation along the height (1st) or " - "width (2nd) axis. 
TensorFlow op \"" - << node.name() << "\" had dilations:[ " << dilations.i(0) << ", " - << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3) - << "]."; + TF_RETURN_IF_ERROR( + ExpectValue(dilations.i_size(), 4, "number of dilations")); + if (dilations.i(0) != 1 || dilations.i(3) != 1) { + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Can only import Conv ops with dilation along the height " + "(1st) or width (2nd) axis. TensorFlow op \"", + node.name(), "\" had dilations:[ ", dilations.i(0), ", ", + dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "].")); + } conv->dilation_height_factor = dilations.i(1); conv->dilation_width_factor = dilations.i(2); } else { @@ -535,9 +566,12 @@ void ConvertConvOperator(const NodeDef& node, } else if (padding == "VALID") { conv->padding.type = PaddingType::kValid; } else { - LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + return tensorflow::errors::InvalidArgument( + "Bad padding (only SAME and VALID are supported)"); } model->operators.emplace_back(conv); + + return tensorflow::Status::OK(); } void ConvertDepthwiseConvOperator(const NodeDef& node, @@ -1408,11 +1442,13 @@ void ConvertTransposeConvOperator(const NodeDef& node, if (existing_transpose) { CHECK(existing_transpose->type == OperatorType::kTranspose); } else { - // Transpose weights from HWIO order to OHWI order, which is more efficient - // for computation + // Transpose weights from HWOI order to OHWI order, which is more efficient + // for computation. (Note that TensorFlow considers the order as HWIO + // because they consider this a backward conv, inverting the sense of + // input/output.) 
TransposeOperator* transpose = new TransposeOperator; string perm_array = CreateConstArray( - model, node.name() + "_transpose_perm", {3, 0, 1, 2}); + model, node.name() + "_transpose_perm", {2, 0, 1, 3}); transpose->inputs = {weights_name, perm_array}; transpose->outputs = {transposed_weights_name}; model->operators.emplace_back(transpose); @@ -1714,15 +1750,15 @@ void ConvertSparseToDenseOperator(const NodeDef& node, } // namespace namespace internal { -Status ImportTensorFlowNode(const tensorflow::NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ImportTensorFlowNode( + const tensorflow::NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, Model* model) { // TODO(ahentz): Historically these functions all CHECK-fail on error. We've // been slowly converting them to return Status. if (node.op() == "Const") { return ConvertConstOperator(node, tf_import_flags, model); } else if (node.op() == "Conv2D") { - ConvertConvOperator(node, tf_import_flags, model); + return ConvertConvOperator(node, tf_import_flags, model); } else if (node.op() == "Conv2DBackpropInput") { ConvertTransposeConvOperator(node, tf_import_flags, model); } else if (node.op() == "DepthwiseConv2dNative") { @@ -1904,6 +1940,8 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, ConvertRandomUniform(node, tf_import_flags, model); } else if (node.op() == "Sin") { ConvertSimpleOperator(node, tf_import_flags, model); + } else if (node.op() == "Log") { + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Select") { ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "SparseToDense") { @@ -1917,7 +1955,7 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else { ConvertUnsupportedOperator(node, tf_import_flags, model); } - return Status::OK(); + return tensorflow::Status::OK(); } } // namespace internal diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc 
b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index 835676662b9cb7ed20e578e2a35747a64ba443dc..d18c329a43411236f8fd5446998c168803b9373a 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/status.h" namespace toco { -using port::Status; using tensorflow::AttrValue; using tensorflow::DT_BOOL; using tensorflow::DT_FLOAT; @@ -33,6 +33,7 @@ using tensorflow::DT_INT64; using tensorflow::DT_QUINT8; using tensorflow::DT_STRING; using tensorflow::NodeDef; +using tensorflow::Status; namespace internal { Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, @@ -117,9 +118,10 @@ TEST_P(ShapeImportTest, ShapeElementIsNegative) { NodeDef node; BuildConstNode({1, -2, 10}, GetParam(), 0, &node); auto status = ImportNode(node); - EXPECT_EQ(status.error_message(), - "Tensor shape should not include negative values (while processing " - "node 'Node1')"); + EXPECT_EQ( + status.error_message(), + "Tensor shape should not include negative values\n\t (while processing " + "node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeElementIsNegative, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -129,7 +131,7 @@ TEST_P(ShapeImportTest, ShapeElementTooLarge) { BuildConstNode({3000000000}, GetParam(), 0, &node); auto status = ImportNode(node); EXPECT_EQ(status.error_message(), - "Shape element overflows (while processing node 'Node1')"); + "Shape element overflows\n\t (while processing node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeElementTooLarge, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -139,7 +141,7 @@ TEST_P(ShapeImportTest, ShapeTooLarge) { BuildConstNode({1000000, 2000000, 2000000, 2000000}, GetParam(), 0, &node); auto 
status = ImportNode(node); EXPECT_EQ(status.error_message(), - "Tensor shape is too large (while processing node 'Node1')"); + "Tensor shape is too large\n\t (while processing node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeTooLarge, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -148,11 +150,11 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) { NodeDef node; BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node); auto status = ImportNode(node); - EXPECT_THAT( - status.error_message(), - ::testing::MatchesRegex( - "Neither input_content .0. nor .*_val .0. have the right " - "dimensions .8. for this .* tensor .while processing node 'Node1'.")); + EXPECT_THAT(status.error_message(), + ::testing::MatchesRegex( + "Neither input_content .0. nor .*_val .0. have the right " + "dimensions .8. for this .* tensor\n\t .while processing " + "node 'Node1'.")); } INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest, ::testing::ValuesIn(TestTypes())); diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 81beb2937293efd0032fa736fa4b197df127d735..7bdec47aa9c1a960d0324c5f6a4b19f69cd056b2 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -155,6 +155,7 @@ enum class AxesOrder { k1HWO, // Our standard for DepthwiseConv weights kHWIM, // TensorFlow DepthwiseConv weights kNHWC, // TensorFlow activations + kHWOI, // TensorFlow back-prop conv weights }; // The type of the scalars in an array. @@ -1221,8 +1222,10 @@ struct TensorFlowSumOperator : Operator { }; // TensorFlow Tile equivalent. Refer to TensorFlow documentation for details. -// Not fully supported, just a placeholder to handle TensorFlow graphs and -// support graph transformations to other operator types by matching sub-graphs. 
+// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: int array with length of rank(input[0]) struct TensorFlowTileOperator : Operator { TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {} }; @@ -1643,8 +1646,8 @@ struct SparseToDenseOperator : Operator { // be used for the transient array at hand. The 'start' and 'end' values are // offsets from the start of the workspace buffer, expressed in bytes. struct Alloc { - int start = 0; - int end = 0; + int64 start = 0; + int64 end = 0; }; inline bool operator<(const Alloc& a, const Alloc& b) { diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 0f104d5e2d02dc852a2720c78995108a00924298..4c9f1aa4b0274b5123bb3baa9b9fca1463bda4c3 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -48,7 +48,7 @@ bool ParseModelFlagsFromCommandLineFlags( "that information from the input file."), Flag("input_arrays", parsed_flags.input_arrays.bind(), parsed_flags.input_arrays.default_value(), - "Names of the output arrays, comma-separated. If not specified, " + "Names of the input arrays, comma-separated. 
If not specified, " "will try to read that information from the input file."), Flag("output_array", parsed_flags.output_array.bind(), parsed_flags.output_array.default_value(), diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD index a954f1d6ba65f21cb99df226790f4bf4951581b1..93fe756a55d378fa205ff88be5e18aff586e5dca 100644 --- a/tensorflow/contrib/lite/toco/python/BUILD +++ b/tensorflow/contrib/lite/toco/python/BUILD @@ -12,6 +12,7 @@ cc_library( deps = [ "//tensorflow/contrib/lite/toco:model_flags_proto_cc", "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", + "//tensorflow/contrib/lite/toco:toco_graphviz_dump_options", "//tensorflow/contrib/lite/toco:toco_port", "//tensorflow/contrib/lite/toco:toco_tooling", "//tensorflow/core:lib", diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc index 5b1db852b4f8e89c1a591cfe18a0ab0aa2db04c9..d93e104038741e6e59608f04115854d611f1f9ae 100644 --- a/tensorflow/contrib/lite/toco/python/toco_python_api.cc +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/python/toco_python_api.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_tooling.h" #include "tensorflow/contrib/lite/toco/toco_types.h" @@ -62,7 +63,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error); if (error) return nullptr; - // Use toco to produce new outputs + // Use TOCO to produce new outputs. toco::ModelFlags model_flags; if (!model_flags.ParseFromString(model_flags_proto_txt)) { LOG(FATAL) << "Model proto failed to parse." 
<< std::endl; @@ -71,6 +72,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, if (!toco_flags.ParseFromString(toco_flags_proto_txt)) { LOG(FATAL) << "Toco proto failed to parse." << std::endl; } + + auto& dump_options = *GraphVizDumpOptions::singleton(); + if (toco_flags.has_dump_graphviz_dir()) { + dump_options.dump_graphviz = toco_flags.dump_graphviz_dir(); + } + if (toco_flags.has_dump_graphviz_include_video()) { + dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video(); + } + + // Convert model. std::unique_ptr model = toco::Import(toco_flags, model_flags, input_contents_txt); toco::Transform(toco_flags, model.get()); diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index a2d753657b0bf6c88f5c94a20a1240fb7c13a37c..7ba2603a952f6611e987901b735e9d4212f014ea 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -99,7 +99,8 @@ void LoadOperatorsMap( Offset>> ExportTensors( const Model& model, const details::TensorsMap& tensors_map, - FlatBufferBuilder* builder, std::vector* buffers_to_write) { + FlatBufferBuilder* builder, std::vector* buffers_to_write, + const std::set& variable_tensor_indices) { // In the end we will need to produce a vector sorted by the indices of the // tensors in the tensors_map. 
std::map> ordered_tensors; @@ -139,9 +140,11 @@ Offset>> ExportTensors( scale, zero_point); int index = tensors_map.at(tensor_name); + bool is_variable = + variable_tensor_indices.find(index) != variable_tensor_indices.end(); ordered_tensors[index] = CreateTensor(*builder, builder->CreateVector(shape), type, buffer_index, - builder->CreateString(tensor_name), q_param); + builder->CreateString(tensor_name), q_param, is_variable); } std::vector> tensor_vector; @@ -239,7 +242,10 @@ Offset>> ExportOperators( const Model& model, const std::map>& ops_by_type, const details::OperatorsMap& operators_map, - const details::TensorsMap& tensors_map, FlatBufferBuilder* builder) { + const details::TensorsMap& tensors_map, FlatBufferBuilder* builder, + std::set* variable_tensor_indices) { + variable_tensor_indices->clear(); + // The operators are in execution order, so we just follow tf.mini order. std::vector> op_vector; for (const auto& op : model.operators) { @@ -256,18 +262,36 @@ Offset>> ExportOperators( int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type)); + auto tflite_op_it = ops_by_type.find(op->type); + BaseOperator* tflite_op = tflite_op_it == ops_by_type.end() + ? nullptr + : tflite_op_it->second.get(); + // This is a custom op unless we can find it in ops_by_type, and even then // it could be a custom op (such as kTensorFlowUnsupported). 
- auto options = Options::Custom(0); - if (ops_by_type.count(op->type) != 0) { - options = ops_by_type.at(op->type)->Serialize(*op, builder); + + std::vector mutating_input_variables; + if (tflite_op) { + options = tflite_op->Serialize(*op, builder); + mutating_input_variables = tflite_op->GetMutatingInputVariables(*op); + + if (!mutating_input_variables.empty()) { + for (int i = 0; i < op->inputs.size(); ++i) { + if (!mutating_input_variables[i]) { + continue; + } + int32_t variable_tensor_index = tensors_map.at(op->inputs[i]); + variable_tensor_indices->insert(variable_tensor_index); + } + } } // The only supported CustomOptionFormat is FLEXBUFFERS now. op_vector.push_back(CreateOperator( *builder, op_index, builder->CreateVector(inputs), builder->CreateVector(outputs), options.type, options.builtin, - options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS)); + options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS, + builder->CreateVector(mutating_input_variables))); } return builder->CreateVector(op_vector); @@ -308,13 +332,10 @@ void Export( Array empty_array; buffers_to_write.push_back(&empty_array); - auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write); - auto inputs = ExportInputTensors(model, tensors_map, &builder); - auto outputs = ExportOutputTensors(model, tensors_map, &builder); - std::set error_summary; auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, &builder, &error_summary); + const string fake_quant_operation_name = "FAKE_QUANT"; if (error_summary.count(fake_quant_operation_name) != 0) { @@ -353,11 +374,18 @@ void Export( << absl::StrJoin(error_summary_final, ", ") << "."; } - auto ops = - ExportOperators(model, ops_by_type, operators_map, tensors_map, &builder); + std::set variable_tensor_indices; + auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map, + &builder, &variable_tensor_indices); + + auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write, + 
variable_tensor_indices); + auto inputs = ExportInputTensors(model, tensors_map, &builder); + auto outputs = ExportOutputTensors(model, tensors_map, &builder); // TODO(aselle): add support to toco for multiple subgraphs. - auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops); + auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops, + /* name */ 0); std::vector> subgraphs = {subgraph}; auto buffers = ExportBuffers(model, buffers_to_write, &builder); diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc index c0e7ab2ef57ed8edf1b7cda08c64f6ae66172af3..cb44a5e6d7356a1cf5597bbe48565c5b1e1949a6 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -113,15 +113,35 @@ void ImportOperators( << operators_table.size(); } string opname = operators_table.at(index); + + // Find and use the appropriate operator deserialization factory. + std::unique_ptr new_op = nullptr; if (ops_by_name.count(opname) == 0) { - LOG(FATAL) << "Op '" << opname << "' not supported"; + string effective_opname = "TENSORFLOW_UNSUPPORTED"; + if (ops_by_name.count(effective_opname) == 0) { + LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found."; + } + new_op = ops_by_name.at(effective_opname) + ->Deserialize(input_op->builtin_options(), + input_op->custom_options()); + if (new_op->type == OperatorType::kTensorFlowUnsupported) { + auto* unsupported_op = + static_cast(new_op.get()); + unsupported_op->tensorflow_op = opname; + // TODO(b/109932940): Remove this when quantized is removed. + // For now, we assume all ops are quantized. 
+ unsupported_op->quantized = true; + } else { + LOG(FATAL) << "Expected a TensorFlowUnsupportedOperator"; + } + } else { + new_op = ops_by_name.at(opname)->Deserialize(input_op->builtin_options(), + input_op->custom_options()); } - - auto new_op = ops_by_name.at(opname)->Deserialize( - input_op->builtin_options(), input_op->custom_options()); model->operators.emplace_back(new_op.release()); auto* op = model->operators.back().get(); + // Make sure all the inputs and outputs are hooked up. auto inputs = input_op->inputs(); for (int i = 0; i < inputs->Length(); i++) { auto input_index = inputs->Get(i); diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 8bfd76db6ea070a7019489d20ab54a4e6eb20179..a0fbb58acafbea72a0678754d1a6ae4275580e44 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -668,6 +668,24 @@ class Lstm : public BuiltinOperator GetMutatingInputVariables( + const Operator& op) const override { + const auto& lstm_op = static_cast(op); + + switch (lstm_op.kernel_type) { + case LstmCellOperator::KERNEL_FULL: + // TODO(ycling): Change the full kernel to use the new variable tensor + // design. This requires moving the state tensors from output to input. 
+ return std::vector(); + case LstmCellOperator::KERNEL_BASIC: { + std::vector mutating_input_variables(op.inputs.size(), false); + mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true; + mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true; + return mutating_input_variables; + } + } + } }; class Mean : public BuiltinOperator> BuildOperatorList() { "LESS", OperatorType::kTensorFlowLess)); ops.emplace_back(new SimpleOperator( "LESS_EQUAL", OperatorType::kTensorFlowLessEqual)); + ops.emplace_back(new SimpleOperator( + "EQUAL", OperatorType::kTensorFlowEqual)); + ops.emplace_back(new SimpleOperator( + "NOT_EQUAL", OperatorType::kTensorFlowNotEqual)); ops.emplace_back(new SimpleOperator("NEG", OperatorType::kNeg)); ops.emplace_back( new SimpleOperator("SELECT", OperatorType::kSelect)); ops.emplace_back( new SimpleOperator("SLICE", OperatorType::kSlice)); + // Element-wise operator ops.emplace_back(new SimpleOperator("SIN", OperatorType::kSin)); - ops.emplace_back(new SimpleOperator( - "EQUAL", OperatorType::kTensorFlowEqual)); - ops.emplace_back(new SimpleOperator( - "NOT_EQUAL", OperatorType::kTensorFlowNotEqual)); + ops.emplace_back(new SimpleOperator("LOG", OperatorType::kLog)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index 5e9c20e40dd6274e0839379883b6dbe53064a0fc..d9ea23edf2b08146773ca58762623397e0f6257c 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -87,6 +87,17 @@ class BaseOperator { // overridden. (See example in `operator_test.cc`) virtual int GetVersion(const Operator& op) const = 0; + // Given a Toco `Operator`, return a list of booleans indicating the op + // mutates which input variables. + // * If the op mutates any input variables, it should return a list of bool + // with the same length as inputs. + // * Otherwise, it will return an empty list. 
+ virtual std::vector GetMutatingInputVariables( + const Operator& op) const { + // Most ops don't have variable tensors. This function can be overridden. + return std::vector(); + } + private: string name_; OperatorType type_; diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 06bbe53516e296efdd0b12c0de06c30cf084b2c1..03bb20b3208196e964d950c0f0954d1fc0ba9e86 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -74,8 +74,10 @@ class OperatorTest : public ::testing::Test { auto new_toco_op = op.Deserialize(output_options->builtin_options(), output_options->custom_options()); - CHECK(dynamic_cast(new_toco_op.get())) - << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to " + CHECK(new_toco_op->type == toco_op.type) + << "The type of the serialized and deserialized" + << HelpfulOperatorTypeName(*new_toco_op) + << " does not match the type of the original " << HelpfulOperatorTypeName(toco_op); return std::unique_ptr(dynamic_cast(new_toco_op.release())); @@ -123,6 +125,7 @@ TEST_F(OperatorTest, SimpleOperators) { OperatorType::kTensorFlowEqual); CheckSimpleOperator( "NOT_EQUAL", OperatorType::kTensorFlowNotEqual); + CheckSimpleOperator("LOG", OperatorType::kLog); } TEST_F(OperatorTest, BuiltinAdd) { diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc index 4867c3a62e68406428644cd05bddf212008c2656..42c5d7e8ebc3a7b90963a92843af616d9e6532d6 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.cc +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -88,6 +88,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { switch (array_data_type) { case ArrayDataType::kFloat: return ::tflite::TensorType_FLOAT32; + case ArrayDataType::kInt16: + return ::tflite::TensorType_INT16; case ArrayDataType::kInt32: return ::tflite::TensorType_INT32; case 
ArrayDataType::kInt64: @@ -109,6 +111,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { switch (::tflite::TensorType(tensor_type)) { case ::tflite::TensorType_FLOAT32: return ArrayDataType::kFloat; + case ::tflite::TensorType_INT16: + return ArrayDataType::kInt16; case ::tflite::TensorType_INT32: return ArrayDataType::kInt32; case ::tflite::TensorType_INT64: @@ -131,6 +135,8 @@ flatbuffers::Offset> DataBuffer::Serialize( switch (array.data_type) { case ArrayDataType::kFloat: return CopyBuffer(array, builder); + case ArrayDataType::kInt16: + return CopyBuffer(array, builder); case ArrayDataType::kInt32: return CopyBuffer(array, builder); case ArrayDataType::kInt64: @@ -154,6 +160,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, switch (tensor.type()) { case ::tflite::TensorType_FLOAT32: return CopyBuffer(buffer, array); + case ::tflite::TensorType_INT16: + return CopyBuffer(buffer, array); case ::tflite::TensorType_INT32: return CopyBuffer(buffer, array); case ::tflite::TensorType_INT64: diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc index 564f303b9bb41a777633ecabd666aa93ec3faefe..8c6ef95bfab0a5e9b410748eabf9570eec52c2e0 100644 --- a/tensorflow/contrib/lite/toco/tflite/types_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc @@ -151,6 +151,12 @@ TEST(DataBuffer, Int32) { ::testing::ElementsAre(1, 1 << 30)); } +TEST(DataBuffer, Int16) { + Array recovered = ToFlatBufferAndBack({1, 1 << 14}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(1, 1 << 14)); +} + TEST(DataBuffer, String) { Array recovered = ToFlatBufferAndBack( {"AA", "BBB", "Best. String. 
Ever."}); diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index 4fe57879fb0f38a21aac01283bc68077aa4be771..ad4e94ded9f9730842a257e065d9aec2b1cbfac8 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -174,4 +174,13 @@ message TocoFlags { // Computation is still done in float, but reduces model size (at the cost of // accuracy and latency). optional bool quantize_weights = 20 [default = false]; + + // Full filepath of folder to dump the graphs at various stages of processing + // GraphViz .dot files. Preferred over --output_format=GRAPHVIZ_DOT in order + // to keep the requirements of the output file. + optional string dump_graphviz_dir = 24; + + // Boolean indicating whether to dump the graph after every graph + // transformation. + optional bool dump_graphviz_include_video = 25; } diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc index 3a5911c28dc5462b5d3747f6af6aa82026a23466..de76fd4032d24eff8a6c2fd0c16a911b9c00186b 100644 --- a/tensorflow/contrib/lite/toco/toco_port.cc +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #if defined(__ANDROID__) && defined(__ARM_ARCH_7A__) @@ -61,8 +63,12 @@ void CheckInitGoogleIsDone(const char* message) { namespace file { // Conversion to our wrapper Status. 
-Status ToStatus(const ::util::Status& uts) { - return Status(uts.ok(), uts.error_message()); +tensorflow::Status ToStatus(const ::util::Status& uts) { + if (!uts.ok()) { + return tensorflow::Status(tensorflow::errors::Code(uts.error_code()), + uts.error_message()); + } + return tensorflow::Status::OK(); } // Conversion to our wrapper Options. @@ -71,7 +77,7 @@ toco::port::file::Options ToOptions(const ::file::Options& options) { return Options(); } -Status Writable(const string& filename) { +tensorflow::Status Writable(const string& filename) { File* f = nullptr; const auto status = ::file::Open(filename, "w", &f, ::file::Defaults()); if (f) { @@ -80,22 +86,24 @@ Status Writable(const string& filename) { return ToStatus(status); } -Status Readable(const string& filename, const file::Options& options) { +tensorflow::Status Readable(const string& filename, + const file::Options& options) { return ToStatus(::file::Readable(filename, ::file::Defaults())); } -Status Exists(const string& filename, const file::Options& options) { +tensorflow::Status Exists(const string& filename, + const file::Options& options) { auto status = ::file::Exists(filename, ::file::Defaults()); return ToStatus(status); } -Status GetContents(const string& filename, string* contents, - const file::Options& options) { +tensorflow::Status GetContents(const string& filename, string* contents, + const file::Options& options) { return ToStatus(::file::GetContents(filename, contents, ::file::Defaults())); } -Status SetContents(const string& filename, const string& contents, - const file::Options& options) { +tensorflow::Status SetContents(const string& filename, const string& contents, + const file::Options& options) { return ToStatus(::file::SetContents(filename, contents, ::file::Defaults())); } @@ -139,37 +147,42 @@ void CheckInitGoogleIsDone(const char* message) { namespace file { -Status Writable(const string& filename) { +tensorflow::Status Writable(const string& filename) { FILE* f = 
fopen(filename.c_str(), "w"); if (f) { fclose(f); - return Status(true, ""); + return tensorflow::Status::OK(); } - return Status(false, "not writable"); + return tensorflow::errors::NotFound("not writable"); } -Status Readable(const string& filename, const file::Options& options) { +tensorflow::Status Readable(const string& filename, + const file::Options& options) { FILE* f = fopen(filename.c_str(), "r"); if (f) { fclose(f); - return Status(true, ""); + return tensorflow::Status::OK(); } - return Status(false, "not readable"); + return tensorflow::errors::NotFound("not readable"); } -Status Exists(const string& filename, const file::Options& options) { +tensorflow::Status Exists(const string& filename, + const file::Options& options) { struct stat statbuf; int ret = stat(filename.c_str(), &statbuf); - return Status(ret != -1, ""); + if (ret == -1) { + return tensorflow::errors::NotFound("file doesn't exist"); + } + return tensorflow::Status::OK(); } -Status GetContents(const string& path, string* output, - const file::Options& options) { +tensorflow::Status GetContents(const string& path, string* output, + const file::Options& options) { output->clear(); int fd = open(path.c_str(), O_RDONLY); if (fd == -1) { - return Status(false, "can't open() for read"); + return tensorflow::errors::NotFound("can't open() for read"); } // Direct read, for speed. @@ -180,25 +193,25 @@ Status GetContents(const string& path, string* output, if (size == 0) { // Done. close(fd); - return Status(true, ""); + return tensorflow::Status::OK(); } else if (size == -1) { // Error. 
close(fd); - return Status(false, "error during read()"); + return tensorflow::errors::Internal("error during read()"); } else { output->append(buffer, size); } } CHECK(0); - return Status(false, "internal error"); + return tensorflow::errors::Internal("internal error"); } -Status SetContents(const string& filename, const string& contents, - const file::Options& options) { +tensorflow::Status SetContents(const string& filename, const string& contents, + const file::Options& options) { int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664); if (fd == -1) { - return Status(false, "can't open() for write"); + return tensorflow::errors::Internal("can't open() for write"); } size_t i = 0; @@ -207,13 +220,13 @@ Status SetContents(const string& filename, const string& contents, ssize_t written = write(fd, &contents[i], to_write); if (written == -1) { close(fd); - return Status(false, "write() error"); + return tensorflow::errors::Internal("write() error"); } i += written; } close(fd); - return Status(true, ""); + return tensorflow::Status::OK(); } string JoinPath(const string& base, const string& filename) { diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h index b00b1e89e856190787d2d40096c9a5321bd80604..17f82b9dd7dcc633aa204038b6d965f4eb6967bb 100644 --- a/tensorflow/contrib/lite/toco/toco_port.h +++ b/tensorflow/contrib/lite/toco/toco_port.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "google/protobuf/text_format.h" #include "tensorflow/contrib/lite/toco/format_port.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/platform.h" #if defined(PLATFORM_GOOGLE) @@ -54,26 +55,6 @@ double round(double x); namespace toco { namespace port { -class Status { - public: - static Status OK() { return Status(true, ""); } - - // Create a failed status with no message. 
- Status() {} - - Status(bool ok, const string& message) : ok_(ok), message_(message) {} - - void AppendMessage(const string& message) { message_ += message; } - - bool ok() const { return ok_; } - - const string error_message() const { return message_; } - - private: - bool ok_ = false; - string message_; -}; - void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags); void CheckInitGoogleIsDone(const char* message); @@ -83,14 +64,14 @@ inline Options Defaults() { Options o; return o; } -Status GetContents(const string& filename, string* contents, - const Options& options); -Status SetContents(const string& filename, const string& contents, - const Options& options); +tensorflow::Status GetContents(const string& filename, string* contents, + const Options& options); +tensorflow::Status SetContents(const string& filename, const string& contents, + const Options& options); string JoinPath(const string& base, const string& filename); -Status Writable(const string& filename); -Status Readable(const string& filename, const Options& options); -Status Exists(const string& filename, const Options& options); +tensorflow::Status Writable(const string& filename); +tensorflow::Status Readable(const string& filename, const Options& options); +tensorflow::Status Exists(const string& filename, const Options& options); } // namespace file // Copy `src` string to `dest`. User must ensure `dest` has enough space. 
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 1fe76f8163cdf23b27f8baaf2d9c6d99b1aa3747..3173d524b7fd043aeec72322875a39d2268ca3f6 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -56,6 +56,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ConvertSqueezeToReshape); transformations->Add(new ConvertTrivialAddNToAdd); transformations->Add(new ConvertTrivialStackToReshape); + transformations->Add(new ConvertTrivialTileToConcat); transformations->Add(new ConvertTrivialTransposeToReshape); transformations->Add(new ConvertReorderAxes); transformations->Add(new ResolveReshapeAttributes); @@ -76,6 +77,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveTensorFlowMatMul); transformations->Add(new FuseBinaryIntoPrecedingAffine); transformations->Add(new FuseBinaryIntoFollowingAffine); + transformations->Add(new FuseBroadcastIntoFollowingBinary); transformations->Add(new MergeReshapeIntoPrecedingTranspose); transformations->Add(new ReorderElementwiseUnary); transformations->Add(new ReorderReshapeTranspose); @@ -94,7 +96,6 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveTensorFlowMerge); transformations->Add(new ResolveSqueezeAttributes); transformations->Add(new ResolveTensorFlowSwitch); - transformations->Add(new ResolveTensorFlowTile); transformations->Add(new ResolveTensorFlowConcat); transformations->Add(new ResolveMultiplyByZero); transformations->Add(new IdentifyDilatedConv); diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 5a82be39395c4505cf8ae893f531ab5f99fea417..92bab5246cb85052b5e0216f1cb8a04736ae7a79 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -30,7 +30,7 @@ limitations under the License. 
#include "tensorflow/contrib/lite/toco/dump_graphviz.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" namespace toco { @@ -585,6 +585,13 @@ void UnextendShape(Shape* shape, int new_shape_size) { shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction); } +bool IsValid(const Shape& shape) { + for (int i = 0; i < shape.dimensions_count(); ++i) { + if (shape.dims(i) < 1) return false; + } + return true; +} + void CheckShapeDimensions(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i @@ -1865,18 +1872,15 @@ void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, output_axes_order == AxesOrder::kHWIO) { // 3210 <- 3210 // HWIO <- OHWI - (*shuffle)[0] = 1; - (*shuffle)[1] = 2; - (*shuffle)[2] = 3; - (*shuffle)[3] = 0; + *shuffle = {1, 2, 3, 0}; } else if (input_axes_order == AxesOrder::kHWIO && output_axes_order == AxesOrder::kOHWI) { // 3210 <- 3210 // OHWI <- HWIO - (*shuffle)[0] = 3; - (*shuffle)[1] = 0; - (*shuffle)[2] = 1; - (*shuffle)[3] = 2; + *shuffle = {3, 0, 1, 2}; + } else if (input_axes_order == AxesOrder::kOHWI && + output_axes_order == AxesOrder::kHWOI) { + *shuffle = {1, 2, 0, 3}; } else { LOG(FATAL) << "Bad shuffle"; } @@ -2022,6 +2026,8 @@ int AxesCount(AxesOrder axes_order) { return 4; case AxesOrder::kNHWC: return 4; + case AxesOrder::kHWOI: + return 4; default: LOG(FATAL) << "Bad AxesOrder"; return 0; diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 3b320e801349595396e573e225ffacf4c7607e52..7681ce9d39ec56f9447896682b52bd4efb1d0e54 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h 
@@ -32,8 +32,9 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" // TODO(aselle): Replace with using a container specific hash override instead. namespace std { @@ -112,7 +113,9 @@ void ExtendShape(Shape* shape, int new_shape_size); // TODO(b/36075966): Clean up when dims superseded by array shape. void UnextendShape(Shape* shape, int new_shape_size); -// Checks (using CHECK) that all dimensions of 'shape' are at least 1. +// Checks that all dimensions of 'shape' are at least 1. +bool IsValid(const Shape& shape); +// Same as above, but reports error using CHECK. void CheckShapeDimensions(const Shape& shape); // Given two shapes with potentially different dimensionality and dimension @@ -315,7 +318,7 @@ void UseArraysExtraInfo(Model* model, bool quantize_output); // doesn't have enough range to represent the sum of elements, an error is // returned. template -port::Status NumElements(const std::vector& shape, U* num_elements) { +tensorflow::Status NumElements(const std::vector& shape, U* num_elements) { static_assert( std::numeric_limits::max() <= std::numeric_limits::max(), "vector type exceed capabilities of NumElements"); @@ -326,17 +329,17 @@ port::Status NumElements(const std::vector& shape, U* num_elements) { // TensorFlow's shapes sometimes include -1 to represent an "unknown" // size but TOCO isn't able to create arrays of unknown sizes and will // crash in RequiredBufferSizeForShape(). 
- return port::Status(false, - "Tensor shape should not include negative values"); + return tensorflow::errors::InvalidArgument( + "Tensor shape should not include negative values"); } if (static_cast(dim) > std::numeric_limits::max() / *num_elements) { *num_elements = 0; - return port::Status(false, "Tensor shape is too large"); + return tensorflow::errors::InvalidArgument("Tensor shape is too large"); } *num_elements *= dim; } - return port::Status::OK(); + return tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc index 87fd30db2cf54824a3c34ed875291d898f1a9e38..a683867374c8b8dcb274478adf6b5fa0691d1c5a 100644 --- a/tensorflow/contrib/lite/toco/tooling_util_test.cc +++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/lib/core/status.h" namespace toco { @@ -99,7 +100,7 @@ static const char kLargeTensorMessage[] = "Tensor shape is too large"; TEST(NumElementsTest, Int) { int count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 1024, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -114,7 +115,7 @@ TEST(NumElementsTest, Int) { TEST(NumElementsTest, Int32) { int32_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 1024, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -129,7 +130,7 @@ TEST(NumElementsTest, Int32) { TEST(NumElementsTest, Int64) { int64_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{16777216, 16777216, 32767}, &count); EXPECT_TRUE(status.ok()); @@ -144,7 +145,7 @@ TEST(NumElementsTest, 
Int64) { TEST(NumElementsTest, UnsignedInt32) { uint32_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 2048, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -159,7 +160,7 @@ TEST(NumElementsTest, UnsignedInt32) { TEST(NumElementsTest, UnsignedInt64) { uint64_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{16777216, 16777216, 65535}, &count); diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD index c5aa27d07c9a5cee0133b6ff99a8833a87d293d1..8857062c000201e1077469fc36e3bf2760924a30 100644 --- a/tensorflow/contrib/lite/tools/benchmark/BUILD +++ b/tensorflow/contrib/lite/tools/benchmark/BUILD @@ -6,8 +6,9 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") -common_copts = ["-Wall"] +common_copts = ["-Wall"] + tflite_copts() cc_binary( name = "benchmark_model", @@ -16,13 +17,10 @@ cc_binary( "logging.h", ], copts = common_copts, - linkopts = select({ + linkopts = tflite_linkopts() + select({ "//tensorflow:android": [ - "-pie", - "-landroid", - "-lm", - "-z defs", - "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm ], "//conditions:default": [], }), @@ -36,7 +34,6 @@ cc_library( srcs = ["command_line_flags.cc"], hdrs = ["command_line_flags.h"], copts = common_copts, - visibility = ["//visibility:private"], ) cc_test( @@ -59,7 +56,6 @@ cc_library( ], hdrs = ["benchmark_tflite_model.h"], copts = common_copts, - linkopts = tflite_linkopts(), deps = [ ":benchmark_model_lib", "//tensorflow/contrib/lite:framework", diff 
--git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md index e6f333aa5bb11449d5bf5d6c60cf77088649df8c..c10826afff6d5569545d4b7df73c88d24d9dcd1a 100644 --- a/tensorflow/contrib/lite/tools/benchmark/README.md +++ b/tensorflow/contrib/lite/tools/benchmark/README.md @@ -46,8 +46,6 @@ adb shell /data/local/tmp/benchmark_model \ --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ --input_layer="Placeholder" \ --input_layer_shape="1,224,224,3" \ - --input_layer_type="uint8" \ - --output_layer="MobilenetV1/Predictions/Reshape_1" \ --num_threads=4 ``` @@ -66,8 +64,6 @@ bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \ --graph=mobilenet_quant_v1_224.tflite \ --input_layer="Placeholder" \ --input_layer_shape="1,224,224,3" \ - --input_layer_type="uint8" \ - --output_layer="MobilenetV1/Predictions/Reshape_1" \ --num_threads=4 ``` @@ -93,80 +89,66 @@ This compiles TFLite with profiling enabled, now you can run the benchmark binar ============================== Run Order ============================== [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] - CONV_2D 0.000 9.132 9.132 0.121% 0.121% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] - DEPTHWISE_CONV_2D 9.135 3.280 3.280 0.043% 0.165% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6] - CONV_2D 12.419 6.877 6.877 0.091% 0.256% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] - DEPTHWISE_CONV_2D 19.299 1.708 1.708 0.023% 0.278% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6] - CONV_2D 21.012 4.162 4.162 0.055% 0.334% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6] - DEPTHWISE_CONV_2D 25.177 3.520 3.520 0.047% 0.380% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6] - CONV_2D 28.701 10.218 10.218 0.136% 0.516% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] - DEPTHWISE_CONV_2D 38.922 0.827 0.827 0.011% 0.527% 0.000 0 
[MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6] - CONV_2D 39.752 1.401 1.401 0.019% 0.545% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6] - DEPTHWISE_CONV_2D 41.156 1.290 1.290 0.017% 0.563% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6] - CONV_2D 42.448 5.995 5.995 0.080% 0.642% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] - DEPTHWISE_CONV_2D 48.445 0.409 0.409 0.005% 0.647% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6] - CONV_2D 48.856 6.167 6.167 0.082% 0.729% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] - DEPTHWISE_CONV_2D 55.026 0.629 0.629 0.008% 0.738% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6] - CONV_2D 55.656 6.464 6.464 0.086% 0.823% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] - DEPTHWISE_CONV_2D 62.124 0.647 0.647 0.009% 0.832% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6] - CONV_2D 62.774 14.666 14.666 0.195% 1.026% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] - DEPTHWISE_CONV_2D 77.444 0.635 0.635 0.008% 1.035% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6] - CONV_2D 78.081 7.186 7.186 0.095% 1.130% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] - DEPTHWISE_CONV_2D 85.270 0.646 0.646 0.009% 1.139% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6] - CONV_2D 85.918 9.529 9.529 0.126% 1.265% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] - DEPTHWISE_CONV_2D 95.451 0.628 0.628 0.008% 1.273% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6] - CONV_2D 96.081 2.077 2.077 0.028% 1.301% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] - DEPTHWISE_CONV_2D 98.162 0.168 0.168 0.002% 1.303% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6] - CONV_2D 98.332 1.007 1.007 0.013% 1.317% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6] - DEPTHWISE_CONV_2D 99.342 0.288 0.288 0.004% 1.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6] - CONV_2D 
99.632 8.197 8.197 0.109% 1.429% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] - AVERAGE_POOL_2D 107.832 0.045 0.045 0.001% 1.430% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool] - CONV_2D 107.878 0.325 0.325 0.004% 1.434% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd] - RESHAPE 108.206 0.003 0.003 0.000% 1.434% 0.000 0 [MobilenetV1/Predictions/Reshape] - SOFTMAX 108.211 0.038 0.038 0.001% 1.434% 0.000 0 [MobilenetV1/Predictions/Softmax] + CONV_2D 0.000 4.269 4.269 0.107% 0.107% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] + DEPTHWISE_CONV_2D 4.270 2.150 2.150 0.054% 0.161% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6] + CONV_2D 6.421 6.107 6.107 0.153% 0.314% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + DEPTHWISE_CONV_2D 12.528 1.366 1.366 0.034% 0.348% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6] + CONV_2D 13.895 4.195 4.195 0.105% 0.454% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6] + DEPTHWISE_CONV_2D 18.091 1.260 1.260 0.032% 0.485% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6] + CONV_2D 19.352 6.652 6.652 0.167% 0.652% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + DEPTHWISE_CONV_2D 26.005 0.698 0.698 0.018% 0.670% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6] + CONV_2D 26.703 3.344 3.344 0.084% 0.754% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6] + DEPTHWISE_CONV_2D 30.047 0.646 0.646 0.016% 0.770% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6] + CONV_2D 30.694 5.800 5.800 0.145% 0.915% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + DEPTHWISE_CONV_2D 36.495 0.331 0.331 0.008% 0.924% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6] + CONV_2D 36.826 2.838 2.838 0.071% 0.995% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] + DEPTHWISE_CONV_2D 39.665 0.439 0.439 0.011% 1.006% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6] + CONV_2D 40.105 5.293 5.293 0.133% 1.139% 0.000 0 
[MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + DEPTHWISE_CONV_2D 45.399 0.352 0.352 0.009% 1.147% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6] + CONV_2D 45.752 5.322 5.322 0.133% 1.281% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + DEPTHWISE_CONV_2D 51.075 0.357 0.357 0.009% 1.290% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6] + CONV_2D 51.432 5.693 5.693 0.143% 1.433% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + DEPTHWISE_CONV_2D 57.126 0.366 0.366 0.009% 1.442% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6] + CONV_2D 57.493 5.472 5.472 0.137% 1.579% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + DEPTHWISE_CONV_2D 62.966 0.364 0.364 0.009% 1.588% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6] + CONV_2D 63.330 5.404 5.404 0.136% 1.724% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] + DEPTHWISE_CONV_2D 68.735 0.155 0.155 0.004% 1.728% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6] + CONV_2D 68.891 2.970 2.970 0.074% 1.802% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6] + DEPTHWISE_CONV_2D 71.862 0.206 0.206 0.005% 1.807% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6] + CONV_2D 72.069 5.888 5.888 0.148% 1.955% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + AVERAGE_POOL_2D 77.958 0.036 0.036 0.001% 1.956% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool] + CONV_2D 77.994 1.445 1.445 0.036% 1.992% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd] + RESHAPE 79.440 0.002 0.002 0.000% 1.992% 0.000 0 [MobilenetV1/Predictions/Reshape] + SOFTMAX 79.443 0.029 0.029 0.001% 1.993% 0.000 0 [MobilenetV1/Predictions/Softmax] ============================== Top by Computation Time ============================== [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] - CONV_2D 62.774 14.666 14.666 0.195% 0.195% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] - CONV_2D 28.701 10.218 
10.218 0.136% 0.330% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] - CONV_2D 85.918 9.529 9.529 0.126% 0.456% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] - CONV_2D 0.000 9.132 9.132 0.121% 0.578% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] - CONV_2D 99.632 8.197 8.197 0.109% 0.686% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] - CONV_2D 78.081 7.186 7.186 0.095% 0.782% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] - CONV_2D 12.419 6.877 6.877 0.091% 0.873% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] - CONV_2D 55.656 6.464 6.464 0.086% 0.958% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] - CONV_2D 48.856 6.167 6.167 0.082% 1.040% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] - CONV_2D 42.448 5.995 5.995 0.080% 1.120% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] - -============================== Top by Memory Use ============================== - [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] - SOFTMAX 108.211 0.038 0.038 0.001% 0.001% 0.000 0 [MobilenetV1/Predictions/Softmax] - RESHAPE 108.206 0.003 0.003 0.000% 0.001% 0.000 0 [MobilenetV1/Predictions/Reshape] - CONV_2D 78.081 7.186 7.186 0.095% 0.096% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] - DEPTHWISE_CONV_2D 77.444 0.635 0.635 0.008% 0.104% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6] - CONV_2D 62.774 14.666 14.666 0.195% 0.299% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] - DEPTHWISE_CONV_2D 62.124 0.647 0.647 0.009% 0.307% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6] - CONV_2D 55.656 6.464 6.464 0.086% 0.393% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] - DEPTHWISE_CONV_2D 55.026 0.629 0.629 0.008% 0.401% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6] - CONV_2D 48.856 6.167 6.167 0.082% 0.483% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] - DEPTHWISE_CONV_2D 48.445 
0.409 0.409 0.005% 0.489% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6] + CONV_2D 19.352 6.652 6.652 0.167% 0.167% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + CONV_2D 6.421 6.107 6.107 0.153% 0.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + CONV_2D 72.069 5.888 5.888 0.148% 0.468% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + CONV_2D 30.694 5.800 5.800 0.145% 0.613% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + CONV_2D 51.432 5.693 5.693 0.143% 0.756% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + CONV_2D 57.493 5.472 5.472 0.137% 0.893% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + CONV_2D 63.330 5.404 5.404 0.136% 1.029% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] + CONV_2D 45.752 5.322 5.322 0.133% 1.162% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + CONV_2D 40.105 5.293 5.293 0.133% 1.295% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + CONV_2D 0.000 4.269 4.269 0.107% 1.402% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] Number of nodes executed: 31 ============================== Summary by node type ============================== [Node type] [count] [avg ms] [avg %] [cdf %] [mem KB] [times called] - CONV_2D 15 1.861 86.679% 86.679% 0.000 0 - DEPTHWISE_CONV_2D 13 0.286 13.321% 100.000% 0.000 0 + CONV_2D 15 1.406 89.270% 89.270% 0.000 0 + DEPTHWISE_CONV_2D 13 0.169 10.730% 100.000% 0.000 0 SOFTMAX 1 0.000 0.000% 100.000% 0.000 0 RESHAPE 1 0.000 0.000% 100.000% 0.000 0 AVERAGE_POOL_2D 1 0.000 0.000% 100.000% 0.000 0 -Timings (microseconds): count=50 first=108164 curr=128308 min=102850 max=197072 avg=150805 std=24368 +Timings (microseconds): count=50 first=79449 curr=81350 min=77385 max=88213 avg=79732 std=1929 Memory (bytes): count=0 31 nodes observed -Average inference timings in us: Warmup: 135310, Init: 12123, no stats: 150988 - +Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 
79760.9 ``` diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc index 2e5b86627322c2c64b8ef665a91595174a5dd8dd..5f803cec197858953180d379c763ed7ebd34ee1d 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -123,29 +123,11 @@ void FillRandomString(tflite::DynamicBuffer* buffer, } } -TfLiteType TfLiteTypeFromString(const string& input_layer_type) { - if (input_layer_type == "string") - return kTfLiteString; - else if (input_layer_type == "float") - return kTfLiteFloat32; - else if (input_layer_type == "uint8") - return kTfLiteUInt8; - else if (input_layer_type == "int32") - return kTfLiteInt32; - else if (input_layer_type == "int64") - return kTfLiteInt64; - else - return kTfLiteNoType; -} - bool PopulateInputLayerInfo( const string& names_string, const string& shapes_string, - const string& types_string, const string& values_string, std::vector* info) { std::vector names = Split(names_string, ','); std::vector shapes = Split(shapes_string, ':'); - std::vector types = Split(types_string, ','); - std::vector values = Split(values_string, ':'); if (names.size() != shapes.size()) { TFLITE_LOG(ERROR) << "The number of items in" @@ -158,17 +140,6 @@ bool PopulateInputLayerInfo( << " --input_layer_shape=1,224,224,4:1,20"; return false; } - if (names.size() != types.size()) { - TFLITE_LOG(ERROR) << "The number of items in" - << " --input_layer_type (" << types_string << ", with " - << types.size() << " items)" - << " must match the number of items in" - << " --input_layer (" << names_string << ", with " - << names.size() << " items)." 
- << " For example --input_layer=input1,input2" - << " --input_layer_type=float,int"; - return false; - } for (int i = 0; i < names.size(); ++i) { info->push_back(BenchmarkTfLiteModel::InputLayerInfo()); @@ -176,10 +147,6 @@ bool PopulateInputLayerInfo( input.name = names[i]; - input.data_type = TfLiteTypeFromString(types[i]); - TFLITE_BENCHMARK_CHECK(input.data_type != kTfLiteNoType) - << types[i] << " was an invalid type"; - TFLITE_BENCHMARK_CHECK(SplitAndParse(shapes[i], ',', &input.shape)) << "Incorrect size string specified: " << shapes[i]; for (int dim : input.shape) { @@ -190,12 +157,6 @@ bool PopulateInputLayerInfo( return false; } } - - if (i < values.size()) { - TFLITE_BENCHMARK_CHECK( - SplitAndParse(values[i], ',', &input.initialization_values)) - << "Incorrect initialization values string specified: " << values[i]; - } } return true; @@ -209,10 +170,6 @@ std::vector BenchmarkTfLiteModel::GetFlags() { Flag("graph", &graph, "graph file name"), Flag("input_layer", &input_layer_string, "input layer names"), Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), - Flag("input_layer_type", &input_layer_type_string, "input layer type"), - Flag("input_layer_values", &input_layer_values_string, - "values to initialize the inputs with"), - Flag("output_layer", &output_layer_string, "output layer name"), Flag("use_nnapi", &use_nnapi, "use nnapi api")}; flags.insert(flags.end(), specific_flags.begin(), specific_flags.end()); @@ -224,8 +181,6 @@ void BenchmarkTfLiteModel::LogFlags() { TFLITE_LOG(INFO) << "Graph: [" << graph << "]"; TFLITE_LOG(INFO) << "Input layers: [" << input_layer_string << "]"; TFLITE_LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]"; - TFLITE_LOG(INFO) << "Input types: [" << input_layer_type_string << "]"; - TFLITE_LOG(INFO) << "Output layers: [" << output_layer_string << "]"; TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]"; } @@ -236,8 +191,7 @@ bool BenchmarkTfLiteModel::ValidateFlags() { return 
false; } return PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, - input_layer_type_string, - input_layer_values_string, &inputs); + &inputs); } uint64_t BenchmarkTfLiteModel::ComputeInputBytes() { @@ -293,8 +247,6 @@ void BenchmarkTfLiteModel::Init() { TFLITE_BENCHMARK_CHECK_EQ(t->name, input.name) << "Tensor # " << i << " is named " << t->name << " but flags call it " << input.name; - TFLITE_BENCHMARK_CHECK_EQ(t->type, input.data_type) - << "Could not match the type of input tensor " << t->name; } // Resize all non-string tensors. diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h index e70f6de1bf461f4e946ec83d8eea83ff4a15bfca..ffb93da964b2da0328616e749abd9c5a84189468 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -64,10 +64,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel { struct InputLayerInfo { std::string name; - TfLiteType data_type; std::vector shape; - // Note that initialization_values is currently unused. 
- std::vector initialization_values; }; private: @@ -78,7 +75,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel { std::string input_layer_type_string; std::string input_layer_shape_string; std::string input_layer_values_string; - std::string output_layer_string; std::vector inputs; bool use_nnapi; ProfilingListener profiling_listener_; diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index 4f2c82ca23011667662c74507fcbd99bcde4c7c0..66cb493e5c5bb9b8645e87dc7f5b274d916f64fc 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -77,7 +77,31 @@ py_test( py_test( name = "metric_ops_test", srcs = ["python/ops/metric_ops_test.py"], - shard_count = 16, + shard_count = 30, + srcs_version = "PY2AND3", + tags = ["noasan"], # times out b/63678675 + deps = [ + ":metrics_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "metric_ops_large_test", + size = "large", + srcs = ["python/ops/metric_ops_large_test.py"], srcs_version = "PY2AND3", tags = ["noasan"], # times out b/63678675 deps = [ diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7acfc383eb9a659a600752cf57b4978daa8a07bc --- /dev/null +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Large tests for metric_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.metrics.python.ops import metric_ops +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testLargeCase(self): + shape = [32, 512, 256, 1] + predictions = random_ops.random_uniform( + shape, 0.0, 1.0, dtype=dtypes_lib.float32) + labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5) + + result, update_op = metric_ops.precision_recall_at_equal_thresholds( + labels=labels, predictions=predictions, num_thresholds=201) + # Run many updates, enough to cause highly inaccurate values if the + # code used float32 for accumulation. 
+ num_updates = 71 + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_updates): + sess.run(update_op) + + prdata = sess.run(result) + + # Since we use random values, we won't know the tp/fp/tn/fn values, but + # tp and fp at threshold 0 should be the total number of positive and + # negative labels, hence their sum should be total number of pixels. + expected_value = 1.0 * np.product(shape) * num_updates + got_value = prdata.tp[0] + prdata.fp[0] + # They should be at least within 1. + self.assertNear(got_value, expected_value, 1.0) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index b13f08a37d9e856d56903324fc6e7cf1457bb191..e720097636fdbe767ca3180345ecd93504c89d55 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2391,34 +2391,6 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): for _ in range(3): self._testResultsEqual(initial_result, result) - def testLargeCase(self): - self.skipTest("Test consistently timing out") - shape = [32, 512, 256, 1] - predictions = random_ops.random_uniform( - shape, 0.0, 1.0, dtype=dtypes_lib.float32) - labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5) - - result, update_op = metric_ops.precision_recall_at_equal_thresholds( - labels=labels, predictions=predictions, num_thresholds=201) - # Run many updates, enough to cause highly inaccurate values if the - # code used float32 for accumulation. 
- num_updates = 71 - - with self.test_session() as sess: - sess.run(variables.local_variables_initializer()) - for _ in xrange(num_updates): - sess.run(update_op) - - prdata = sess.run(result) - - # Since we use random values, we won't know the tp/fp/tn/fn values, but - # tp and fp at threshold 0 should be the total number of positive and - # negative labels, hence their sum should be total number of pixels. - expected_value = 1.0 * np.product(shape) * num_updates - got_value = prdata.tp[0] + prdata.fp[0] - # They should be at least within 1. - self.assertNear(got_value, expected_value, 1.0) - def _testCase(self, predictions, labels, @@ -4727,199 +4699,204 @@ class StreamingSparseRecallTest(test.TestCase): self._test_sparse_recall_at_top_k( labels, top_k_predictions, expected=1.0 / 2) - def test_one_label_at_k1_weighted(self): + def _test_one_label_at_k1_weighted(self, labels): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] top_k_predictions = [[3], [3]] - sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], - [0, 0, 1, 0]]) - dense_labels = np.array([[3], [2]], dtype=np.int64) - for labels in (sparse_labels, dense_labels): - # Class 3: 1 label, 2 predictions, 1 correct. 
- self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0,)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(2.0,)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(2.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=NAN, - class_id=3, - weights=(0.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=NAN, - class_id=3, - weights=(0.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=NAN, - class_id=3, - weights=(0.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=NAN, - class_id=3, - weights=(0.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=2.0 / 2, - class_id=3, - weights=(2.0, 3.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=2.0 / 2, - class_id=3, - weights=(2.0, 3.0)) - 
self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=3.0 / 3, - class_id=3, - weights=(3.0, 2.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=3.0 / 3, - class_id=3, - weights=(3.0, 2.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=0.3 / 0.3, - class_id=3, - weights=(0.3, 0.6)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=0.3 / 0.3, - class_id=3, - weights=(0.3, 0.6)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=0.6 / 0.6, - class_id=3, - weights=(0.6, 0.3)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=0.6 / 0.6, - class_id=3, - weights=(0.6, 0.3)) + # Class 3: 1 label, 2 predictions, 1 correct. + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(2.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=NAN, + class_id=3, + weights=(0.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=3, + weights=(0.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=NAN, + class_id=3, + weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + 
class_id=3, + weights=(0.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=2.0 / 2, + class_id=3, + weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=2.0 / 2, + class_id=3, + weights=(2.0, 3.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=3.0 / 3, + class_id=3, + weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=3.0 / 3, + class_id=3, + weights=(3.0, 2.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=0.3 / 0.3, + class_id=3, + weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.3 / 0.3, + class_id=3, + weights=(0.3, 0.6)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=0.6 / 0.6, + class_id=3, + weights=(0.6, 0.3)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.6 / 0.6, + class_id=3, + weights=(0.6, 0.3)) - # All classes: 2 labels, 2 predictions, 1 correct. 
- self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=NAN, weights=(0.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=NAN, weights=(0.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,)) + # All classes: 2 labels, 2 predictions, 1 correct. + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=NAN, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, weights=(0.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.0 / 1, 
weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3)) - self._test_sparse_recall_at_top_k( - labels, 
top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3)) + + def test_one_label_at_k1_weighted_sparse_labels(self): + sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], + [0, 0, 1, 0]]) + self._test_one_label_at_k1_weighted(sparse_labels) + + def test_one_label_at_k1_weighted_dense_labels(self): + dense_labels = np.array([[3], [2]], dtype=np.int64) + self._test_one_label_at_k1_weighted(dense_labels) def test_three_labels_at_k5_nan(self): predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py index e4e5ccc33472ad5a12bd8111fb1ff6ebbd6f45f9..ef34f7bf7bf3eba047b50ce8abf883b0ed741a63 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py @@ -26,26 +26,32 @@ from tensorflow.python.training import optimizer class LossScaleOptimizer(optimizer.Optimizer): + # TODO(jamesqin): move mixed precision training explanation to __init__ + # docstring. """An optimizer that applies loss scaling in backprop. - This class is useful for mixed precision training on GPUs (or other potential - accelerators), which is an approach to improve compute throughput without loss - of model quality. - - The commmon configuration of mixed precision models is the following: - * variables are kept in high precision (e.g. float32). - * computations are done in lower precision (e.g. float16). variables are - casted to lower precision before they're used. - * (in training), final gradients are casted back to variable precision and get - applied. 
- - Because computations happen in lower precision, gradients in the backprop pass - might underflow in the smaller dynamic range, causing a model to converge at a - suboptimal level. This optimizer multiplies the loss by a factor before - backprop starts to prevent underflow. Before gradients are applied, they are - casted to higher precision and down-scaled by the same factor, so - mathematically the variable updates are no different from regular - same-precision training. + This class is useful for "mixed precision training" on GPUs (or other + potential accelerators), an approach to improve compute throughput without + compromising model quality. + + The canonical way to perform mixed precision training is the following: + * Model variables are kept in high precision (e.g. float32). + * Computations are done in lower precision (e.g. float16), which enjoys + performance speedup by virtue of hardware support. Variables are casted to + lower precision before they're used. + * Final gradients are casted back to high precision dtype, then used to update + variables. + + The side-effect of performing computation in lower precision, is that it comes + with smaller numerical range. During backproping, small gradients might + underflow in the reduced numerical range, causing a model to converge at + suboptimal level. + + To prevent underflow, this optimizer multiplies the loss by a factor before + backprop starts. Consequently, the gradients are linearly scaled up by the + same factor, thus not falling into the underflow zone. After that, to perserve + the correctness of backprop, the gradients are down-scaled by the same factor, + casted to the (higher) variable precision, then applied on the variables. 
See [Nvidia's manual on mixed precision training]( https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 13aa1d7e7a11877373a848c1ba865aa418790cd0..114b344d38413208755a47f36f45badc1a5ecaa9 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -28,6 +28,7 @@ py_library( "python/training/reg_adagrad_optimizer.py", "python/training/sign_decay.py", "python/training/variable_clipping_optimizer.py", + "python/training/weight_decay_optimizers.py", ], srcs_version = "PY2AND3", deps = [ @@ -194,6 +195,25 @@ py_test( ], ) +py_test( + name = "weight_decay_optimizers_test", + srcs = ["python/training/weight_decay_optimizers_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:session", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "drop_stale_gradient_optimizer_test", srcs = ["python/training/drop_stale_gradient_optimizer_test.py"], diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 4c13c8e247185213b798eb733ddcf65a07a8f64d..5df5d35f8e4f8fcc2c5aa09bd8f3254e16e3a74f 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -27,6 +27,7 @@ from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * +from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from 
tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * @@ -46,6 +47,10 @@ _allowed_symbols = [ 'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', + 'MomentumWOptimizer', + 'AdamWOptimizer', + 'DecoupledWeightDecayExtension', + 'extend_with_decoupled_weight_decay', 'ScipyOptimizerInterface', 'VariableClippingOptimizer', 'MultitaskOptimizerWrapper', diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa40aeb45d4ec15140bdfc5ebd824e8aa08d8d9 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -0,0 +1,326 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Base class to make optimizers weight decay ready.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.training import optimizer +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.training import adam +from tensorflow.python.training import momentum as momentum_opt +from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import resource_variable_ops + + +class DecoupledWeightDecayExtension(object): + """This class allows to extend optimizers with decoupled weight decay. + + It implements the decoupled weight decay described by Loshchilov & Hutter + (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is + decoupled from the optimization steps w.r.t. to the loss function. + For SGD variants, this simplifies hyperparameter search since it decouples + the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + This class alone is not an optimizer but rather extends existing + optimizers with decoupled weight decay. We explicitly define the two examples + used in the above paper (SGDW and AdamW), but in general this can extend + any OptimizerX by using + `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`. + In order for it to work, it must be the first class the Optimizer with + weight decay inherits from, e.g. + + ```python + class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + def __init__(self, weight_decay, *args, **kwargs): + super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs). 
+ ``` + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + """ + + def __init__(self, weight_decay, **kwargs): + """Construct the extension class that adds weight decay to an optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value, the factor by which + a variable is decayed in the update step. + decay_var_list: Optional list or tuple or set of `Variable` objects to + decay. + """ + self._decay_var_list = None # is set in minimize or apply_gradients + self._weight_decay = weight_decay + # The tensors are initialized in call to _prepare + self._weight_decay_tensor = None + super(DecoupledWeightDecayExtension, self).__init__(**kwargs) + + def minimize(self, loss, global_step=None, var_list=None, + gate_gradients=optimizer.Optimizer.GATE_OP, + aggregation_method=None, colocate_gradients_with_ops=False, + name=None, grad_loss=None, decay_var_list=None): + """Add operations to minimize `loss` by updating `var_list` with decay. + + This function is the same as Optimizer.minimize except that it allows to + specify the variables that should be decayed using decay_var_list. + If decay_var_list is None, all variables in var_list are decayed. + + For more information see the documentation of Optimizer.minimize. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).minimize( + loss, global_step=global_step, var_list=var_list, + gate_gradients=gate_gradients, aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, name=name, + grad_loss=grad_loss) + + def apply_gradients(self, grads_and_vars, global_step=None, name=None, + decay_var_list=None): + """Apply gradients to variables and decay the variables. 
+ + This function is the same as Optimizer.apply_gradients except that it + allows to specify the variables that should be decayed using + decay_var_list. If decay_var_list is None, all variables in var_list + are decayed. + + For more information see the documentation of Optimizer.apply_gradients. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).apply_gradients( + grads_and_vars, global_step=global_step, name=name) + + def _prepare(self): + weight_decay = self._weight_decay + if callable(weight_decay): + weight_decay = weight_decay() + self._weight_decay_tensor = ops.convert_to_tensor( + weight_decay, name="weight_decay") + # Call the optimizers _prepare function. + super(DecoupledWeightDecayExtension, self)._prepare() + + def _decay_weights_op(self, var): + if not self._decay_var_list or var in self._decay_var_list: + return var.assign_sub(self._weight_decay * var, self._use_locking) + return control_flow_ops.no_op() + + def _decay_weights_sparse_op(self, var, indices, scatter_add): + if not self._decay_var_list or var in self._decay_var_list: + return scatter_add(var, indices, -self._weight_decay * var, + self._use_locking) + return control_flow_ops.no_op() + + # Here, we overwrite the apply functions that the base optimizer calls. + # super().apply_x resolves to the apply_x function of the BaseOptimizer. 
+ def _apply_dense(self, grad, var): + with ops.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var) + + def _resource_apply_dense(self, grad, var): + with ops.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, self)._resource_apply_dense( + grad, var) + + def _apply_sparse(self, grad, var): + scatter_add = state_ops.scatter_add + decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add) + with ops.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, self)._apply_sparse( + grad, var) + + def _resource_scatter_add(self, x, i, v, _=None): + # last argument allows for one overflow argument, to have the same function + # signature as state_ops.scatter_add + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + scatter_add = self._resource_scatter_add + decay_op = self._decay_weights_sparse_op(var, indices, scatter_add) + with ops.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse( + grad, var, indices) + + +def extend_with_decoupled_weight_decay(base_optimizer): + """Factory function returning an optimizer class with decoupled weight decay. + + Returns an optimizer class. An instance of the returned class computes the + update step of `base_optimizer` and additionally decays the weights. + E.g., the class returned by + `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to + `tf.contrib.opt.AdamWOptimizer`. + + The API of the new optimizer class slightly differs from the API of the + base optimizer: + - The first argument to the constructor is the weight decay rate. 
+ - `minimize` and `apply_gradients` accept the optional keyword argument + `decay_var_list`, which specifies the variables that should be decayed. + If `None`, all variables that are optimized are decayed. + + Usage example: + ```python + # MyAdamW is a new class + MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer) + # Create a MyAdamW object + optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) + sess.run(optimizer.minimize(loss, decay_variables=[var1, var2])) + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + ``` + + Args: + base_optimizer: An optimizer class that inherits from tf.train.Optimizer. + + Returns: + A new optimizer class that inherits from DecoupledWeightDecayExtension + and base_optimizer. + """ + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, + base_optimizer): + """Base_optimizer with decoupled weight decay. + + This class computes the update step of `base_optimizer` and + additionally decays the variable with the weight decay being decoupled from + the optimization steps w.r.t. to the loss function, as described by + Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf). + For SGD variants, this simplifies hyperparameter search since + it decouples the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. 
+ """ + + def __init__(self, weight_decay, *args, **kwargs): + # super delegation is necessary here + # pylint: disable=useless-super-delegation + super(OptimizerWithDecoupledWeightDecay, self).__init__( + weight_decay, *args, **kwargs) + # pylint: enable=useless-super-delegation + + return OptimizerWithDecoupledWeightDecay + + +@tf_export("contrib.opt.MomentumWOptimizer") +class MomentumWOptimizer(DecoupledWeightDecayExtension, + momentum_opt.MomentumOptimizer): + """Optimizer that implements the Momentum algorithm with weight_decay. + + This is an implementation of the SGDW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + It computes the update step of `train.MomentumOptimizer` and additionally + decays the variable. Note that this is different from adding + L2 regularization on the variables to the loss. Decoupling the weight decay + from other hyperparameters (in particular the learning rate) simplifies + hyperparameter search. + + For further information see the documentation of the Momentum Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.train.MomentumOptimizer, + weight_decay=weight_decay) + ``` + """ + + def __init__(self, weight_decay, learning_rate, momentum, + use_locking=False, name="MomentumW", use_nesterov=False): + """Construct a new MomentumW optimizer. + + For further information see the documentation of the Momentum Optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value. The weight decay. + learning_rate: A `Tensor` or a floating point value. The learning rate. + momentum: A `Tensor` or a floating point value. The momentum. + use_locking: If `True` use locks for update operations. + name: Optional name prefix for the operations created when applying + gradients. Defaults to "Momentum". + use_nesterov: If `True` use Nesterov Momentum. 
+ See [Sutskever et al., 2013]( + http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). + This implementation always computes gradients at the value of the + variable(s) passed to the optimizer. Using Nesterov Momentum makes the + variable(s) track the values called `theta_t + mu*v_t` in the paper. + + @compatibility(eager) + When eager execution is enabled, learning_rate, weight_decay and momentum + can each be a callable that takes no arguments and returns the actual value + to use. This can be useful for changing these values across different + invocations of optimizer functions. + @end_compatibility + """ + super(MomentumWOptimizer, self).__init__( + weight_decay, learning_rate=learning_rate, momentum=momentum, + use_locking=use_locking, name=name, use_nesterov=use_nesterov) + + +@tf_export("contrib.opt.AdamWOptimizer") +class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + """Optimizer that implements the Adam algorithm with weight decay. + + This is an implementation of the AdamW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + + It computes the update step of `train.AdamOptimizer` and additionally decays + the variable. Note that this is different from adding L2 regularization on + the variables to the loss: it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + For further information see the documentation of the Adam Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.train.AdamOptimizer, weight_decay=weight_decay) + ``` + """ + + def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999, + epsilon=1e-8, use_locking=False, name="AdamW"): + """Construct a new AdamW optimizer. 
+ + For further information see the documentation of the Adam Optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value. The weight decay. + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". + """ + super(AdamWOptimizer, self).__init__( + weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2, + epsilon=epsilon, use_locking=use_locking, name=name) diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..74d1cdbbdac8724518937d141a976abf9fec6ce3 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for optimizers with weight decay.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adam +from tensorflow.contrib.opt.python.training import weight_decay_optimizers + +WEIGHT_DECAY = 0.01 + + +def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9, + beta2=0.999, epsilon=1e-8): + lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) - + (param * WEIGHT_DECAY)) + return param_t, m_t, v_t + + +def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_): + # v, t are not needed for momentum optimizer + m = momentum * m + g_t + param_t = param - lr * m - param * WEIGHT_DECAY + return param_t, m, None + + +class WeightDecayOptimizerTest(test.TestCase): + + def doTest(self, optimizer, update_fn, optimizer_name, slot_name, + use_resource=False, do_sparse=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.test_session(graph=ops.Graph()): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + if do_sparse: + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices(constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), + constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices(constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), + constant_op.constant([2])) + else: + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = optimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. 
+ self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of the optimizer + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0) + var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/%s:0" % (i, optimizer_name), + opt.get_slot(var=var0, name=slot_name).name) + + +class AdamWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY) + + def testSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True) + + +class MomentumWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9) + + def testSparse(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, 
momentumw_update_numpy, "MomentumW", + "momentum", use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=True) + + +class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + AdamW = weight_decay_optimizers.extend_with_decoupled_weight_decay( + adam.AdamOptimizer) + return AdamW(WEIGHT_DECAY) + + def testBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", + use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", + use_resource=True) + + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py index 9305c6a11c4ec898c82553773e8e7277a54ab82e..85918bf8506623cf5e0c9106ae9ed80e233f5a7d 100644 --- a/tensorflow/contrib/solvers/python/ops/linear_equations.py +++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import linalg_ops def conjugate_gradient(operator, diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc index d389050e67f9a9e48b91583e5088058ec4e2832f..06553929dc44ca1f75ce64532a4dcdf1c8aae3eb 100644 --- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc +++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc @@ -23,15 +23,23 @@ REGISTER_OP("CrossReplicaSum") .Input("input: T") 
.Output("output: T") .Attr("T: {bfloat16, float}") + .Attr("group_assignment: list(int) = []") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( An Op to sum inputs across replicated TPU instances. Each -instance supplies its own input, and the output of each is the sum of -all the inputs. +instance supplies its own input. If group_assignment is empty, the output of +each is the sum of all the inputs, otherwise the output of each is the sum of +the inputs belonging to the same group. + +For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing +group_assignment=`[0,1,0,1]` sets `A, C` as group 0, and `B, D` as group 1. +Thus we get the outputs: `[A+C, B+D, A+C, B+D]`. input: The local input to the sum. output: The sum of all the distributed inputs. T: The type of elements to be summed. +group_assignment: The list of group ids. `group_assignment[i]` represents the + group id of replica i. )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index ab2a7a0d4bec48d6b3b459bb3144e8ddae614ca0..f632c953c85fcc335410c10db785265af9d8ddf3 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -44,6 +44,27 @@ REGISTER_OP("TPUReplicatedInput") " with other shapes."); } c->set_output(0, cur); + + // If this is a resource, unify the resource shapes. 
+ DataType dtype; + TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype)); + if (dtype == DT_RESOURCE) { + const std::vector* shapes_and_types = + nullptr; + for (int i = c->num_inputs() - 1; i >= 0; --i) { + if (shapes_and_types) { + if (!c->MergeInputHandleShapesAndTypes(i, *shapes_and_types)) { + return errors::InvalidArgument( + "Incompatible resource shapes for replicated TPU input."); + } + } else { + shapes_and_types = c->input_handle_shapes_and_types(i); + } + } + if (shapes_and_types) { + c->set_output_handle_shapes_and_types(0, *shapes_and_types); + } + } return Status::OK(); }) .Doc( diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index dbf1ab6bbf0ddc7429d8e19279451eb862981e0c..38d1c3049ef7185f2f9f448361029d066678cdae 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -49,11 +49,11 @@ tf_cc_binary( ":tpu_profiler_analysis_proto_cc", ":tpu_profiler_proto_cc", ":version", + "//tensorflow:grpc++", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/platform/cloud:gcs_file_system", - "@grpc//:grpc++_unsecure", ], ) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 99485322c6b9434f4c1700b9e2a6af00a65f794f..f80f5652af79d410946971573ae160fdd0b85f6d 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -18,7 +18,7 @@ limitations under the License. // Initiates a TPU profiling on the TPUProfiler service at service_addr, // receives and dumps the profile data to a tensorboard log directory. 
-#include "grpc++/grpc++.h" +#include "grpcpp/grpcpp.h" #include #include diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 508c7a842fb82ec080082d7e7f02f8d2f2a79447..7f1d25732e21b5dea4e605f6caa141ca9d3d02c6 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -35,19 +35,19 @@ flags.DEFINE_string( None, help='GCE zone where the Cloud TPU is located in. If not specified, we ' 'will attempt to automatically detect the GCE project from metadata.') -flags.DEFINE_string('tpu_name', None, +flags.DEFINE_string('tpu', None, 'Name of the Cloud TPU for Cluster Resolvers. You must ' 'specify either this flag or --service_addr.') # Tool specific parameters flags.DEFINE_string( 'service_addr', None, 'Address of TPU profiler service e.g. ' - 'localhost:8466, you must specify either this flag or --tpu_name.') + 'localhost:8466, you must specify either this flag or --tpu.') flags.DEFINE_string( 'workers_list', None, 'The list of worker TPUs that we are about to profile' - ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu_name or ' + ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu or ' '--service_addr to profile a subset of tpu nodes. You can also use only' - '--tpu_name and leave this flag unspecified to profile all the tpus.') + '--tpu and leave this flag unspecified to profile all the tpus.') flags.DEFINE_string('logdir', None, 'Path of TensorBoard log directory e.g. 
/tmp/tb_log, ' 'gs://tb_bucket') @@ -76,19 +76,19 @@ def run_main(): def main(unused_argv=None): tf.logging.set_verbosity(tf.logging.INFO) - if FLAGS.service_addr is None and FLAGS.tpu_name is None: - sys.exit('You must specify either --service_addr or --tpu_name.') + if FLAGS.service_addr is None and FLAGS.tpu is None: + sys.exit('You must specify either --service_addr or --tpu.') tpu_cluster_resolver = None if FLAGS.service_addr is not None: - if FLAGS.tpu_name is not None: - tf.logging.warn('Both --service_addr and --tpu_name are set. Ignoring ' - '--tpu_name and using --service_addr.') + if FLAGS.tpu is not None: + tf.logging.warn('Both --service_addr and --tpu are set. Ignoring ' + '--tpu and using --service_addr.') service_addr = FLAGS.service_addr else: tpu_cluster_resolver = ( tf.contrib.cluster_resolver.TPUClusterResolver( - [FLAGS.tpu_name], + [FLAGS.tpu], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)) service_addr = tpu_cluster_resolver.get_master() diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index ebd478fd02295108b9d2454963eb06165828b523..f97a972f01a3ba5582df3675439aa962886f796e 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -20,7 +20,7 @@ from __future__ import print_function from setuptools import setup -_VERSION = '1.6.0' +_VERSION = '1.7.0' CONSOLE_SCRIPTS = [ 'capture_tpu_profile=cloud_tpu_profiler.main:run_main', @@ -46,7 +46,7 @@ setup( # 3 - Alpha # 4 - Beta # 5 - Production/Stable - 'Development Status :: 4 - Beta', + 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h index 618479e1a6ccf26a4103ea1f182b662d7d9998da..bd9ba6697edd9ef14dd3af0d2c9b77df9ec6917a 100644 --- 
a/tensorflow/contrib/tpu/profiler/version.h +++ b/tensorflow/contrib/tpu/profiler/version.h @@ -16,6 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ #define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ -#define TPU_PROFILER_VERSION "1.6.0" +#define TPU_PROFILER_VERSION "1.7.0" #endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 14c63a79763300dcfe8d6c8e09b90f8e9c772358..bf442d9116d2ceca499ffc66258c64b5b94dd881 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -38,9 +38,8 @@ if platform.system() != "Windows": @ops.RegisterGradient("CrossReplicaSum") def _cross_replica_sum_grad(op, grad): - del op # Unused # The gradient of a cross replica sum is also a cross-replica sum. - return gen_tpu_ops.cross_replica_sum(grad) + return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment")) # This extra type checking exists to give a more helpful error message in # the common case that uint8 and int64 values are infed. Remove when both diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 1c482950e64a9537a2996df66ed9403e53cf8a71..dc473c5846aafc5a92756dfb8259f7f8dc14b98d 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -591,16 +591,22 @@ def split_compile_and_replicate(computation, with tpu_function.tpu_shard_context( num_replicas), ops.control_dependencies([metadata]): - # The EncapsulateTPUComputations rewrite needs to identify the - # replicated arguments inside each computation. Adds identity operators - # tagged with an attribute _tpu_replicated_input to identify the - # replicated inputs. + # For backward compatibility reasons, we tag replicated inputs with the + # _tpu_replicated_input attribute. 
This does nothing and exists only for + # backward compatibility. + # TODO(phawkins): delete the attr_scope after 6/28/2018. # pylint: disable=protected-access - with graph._attr_scope({"_tpu_replicated_input": - attr_value_pb2.AttrValue(b=True)}): + with graph._attr_scope({ + "_tpu_replicated_input": attr_value_pb2.AttrValue(b=True) + }): + # Add identity ops so even unused inputs are "consumed" by the + # computation. This is to avoid orphaned TPUReplicatedInput nodes. + # TODO(phawkins): consider instead pruning unused TPUReplicatedInput + # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. computation_inputs = [ array_ops.identity(x, name="replicated_input_{}".format(i)) - for i, x in enumerate(computation_inputs)] + for i, x in enumerate(computation_inputs) + ] # pylint: enable=protected-access # If there is an infeed queue, adds the dequeued values to the @@ -623,15 +629,16 @@ def split_compile_and_replicate(computation, vscope.set_use_resource(saved_use_resource) - # If the computation returns `None`, add `no_op` here so that when user - # fetches `no_op` returned by this function, the TPUExecute node will be - # triggered. + # If the computation returns `None`, make it an empty tuple. if outputs is None: - outputs = (control_flow_ops.no_op(),) + outputs = tuple() # If the computation only returned one value, makes it a tuple. if not isinstance(outputs, (list, tuple)): outputs = (outputs,) + # Append `no_op` here so that fetching any return value of this function + # will trigger TPUExecute node. 
+ outputs += (control_flow_ops.no_op(),) try: with ops.device(core(0)): outputs = [ diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 64ae35dfc5e6d385a23c2dba15562d71aae4d497..e94bd78833f6cbe9adb1b6ca3f29a88bd8a53f64 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -1343,8 +1343,55 @@ class _ModelFnWrapper(object): key, tensor)) return predictions + def _validate_model_features_and_labels(self, + features, + labels, + is_export_mode): + """Validates that the features and labels for the model function are valid. + + A valid features/labels object is the one with: + - Type: Tensor or a dictionary of Tensors + - Static shape if is_export_mode is False. + + Args: + features: the features that would be input to the model function. + labels: the labels that would be input to the model function. + is_export_mode: boolean value specifying if in export mode. + + Raises: + TypeError: If features/labels are not of the correct type. + ValueError: If features/labels have dynamic shape. + """ + + def validate(obj, obj_name): + """Helper validate function.""" + if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict): + raise TypeError( + 'The {} to the model returned by input_fn must be either a Tensor ' + 'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name, + obj)) + if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode): + return + if isinstance(obj, ops.Tensor): + if not obj.get_shape().is_fully_defined(): + raise ValueError( + 'The {} to the model returned by input_fn must have static shape.' + ' Tensor: {}'.format(obj_name, obj)) + else: + for (key, tensor) in obj.items(): + if not tensor.get_shape().is_fully_defined(): + raise ValueError( + 'The {} to the model returned by input_fn must have static ' + 'shape. 
Key: \'{}\', Tensor: {}'.format( + obj_name, key, tensor)) + + validate(features, 'features') + if labels is not None: + validate(labels, 'labels') + def _call_model_fn(self, features, labels, is_export_mode=False): """Calls the model_fn with required parameters.""" + self._validate_model_features_and_labels(features, labels, is_export_mode) model_fn_args = function_utils.fn_args(self._model_fn) kwargs = {} @@ -1855,11 +1902,6 @@ class TPUEstimator(estimator_lib.Estimator): ... ``` - Current limitations: - -------------------- - - 1. Outside compilation does not work yet (b/79991729). - """ def __init__(self, @@ -2078,10 +2120,21 @@ class TPUEstimator(estimator_lib.Estimator): # Reconstruct `tensors`, but with `tpu_tensors` replaced with # `tpu_tensors_on_cpu`. - new_tensors = [ - tpu_tensors_on_cpu.pop(0) if _is_tpu_tensor(t) else t - for t in tensors - ] + new_tensors = [] + for t in tensors: + if _is_tpu_tensor(t): + new_tensors.append(tpu_tensors_on_cpu.pop(0)) + elif t is None: + new_tensors.append(None) + else: + # Only fetching `tpu_tensors_on_cpu` does not trigger + # TPU computation and blocks, so we add the control dependency here. + control_inputs = (tpu_tensors_on_cpu + if isinstance(tpu_tensors_on_cpu, (list, tuple)) + else (tpu_tensors_on_cpu,)) + with ops.control_dependencies(control_inputs): + new_tensors.append(array_ops.identity(t)) + # Reconstruct `tensors_dict`. new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) # Reconstruct `export_outputs`. 
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py index e76cf83e4ddcd86ab3971bcecefe2e2dc979bf63..15f99d7eebddd46f9f6902b68f01e42359a72cbe 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.ops.losses import losses @@ -32,7 +34,8 @@ class CrossShardOptimizer(optimizer.Optimizer): def __init__(self, opt, reduction=losses.Reduction.MEAN, - name="CrossShardOptimizer"): + name="CrossShardOptimizer", + group_assignment=None): """Construct a new cross-shard optimizer. Args: @@ -40,6 +43,8 @@ class CrossShardOptimizer(optimizer.Optimizer): reduction: The reduction to apply to the shard losses. name: Optional name prefix for the operations created when applying gradients. Defaults to "CrossShardOptimizer". + group_assignment: Optional list of group ids for applying the optimizer + to subgroups. Raises: ValueError: If reduction is not a valid cross-shard reduction. @@ -50,6 +55,35 @@ class CrossShardOptimizer(optimizer.Optimizer): super(CrossShardOptimizer, self).__init__(False, name) self._opt = opt self._reduction = reduction + self._group_assignment = group_assignment + + def _verify_and_get_subgroup_size(self, group_assignment, num_shards): + """Verify group_assignment and get the subgroup size". + + Args: + group_assignment: list of group ids for applying the optimizer + to subgroups. + num_shards: The number of TPU shards. + + Returns: + The size of one subgroup in group_assignment. + + Raises: + ValueError: If group_assignment is invalid. 
+ """ + if not group_assignment: + return None + if len(group_assignment) != num_shards: + raise ValueError("The size of group_assignment does not equal to " + "num_shard({0}). Got group_assignment={1}".format( + num_shards, self._group_assignment)) + subgroup_size_list = dict(collections.Counter(group_assignment)).values() + if all(subgroup_size_list[0] == size for size in subgroup_size_list): + return subgroup_size_list[0] + else: + raise ValueError("The size of each subgroup in group_assignment must " + "be equal. Got group_assignment={}".format( + self._group_assignment)) def compute_gradients(self, loss, var_list=None, **kwargs): """Compute gradients of "loss" for the variables in "var_list". @@ -71,7 +105,8 @@ class CrossShardOptimizer(optimizer.Optimizer): A list of (gradient, variable) pairs. Raises: - ValueError: If not within a tpu_shard_context. + ValueError: If not within a tpu_shard_context or group_assignment is + invalid. """ num_shards = tpu_function.get_tpu_context().number_of_shards if num_shards is None: @@ -79,9 +114,17 @@ class CrossShardOptimizer(optimizer.Optimizer): "CrossShardOptimizer should be used within a tpu_shard_context, but " "got unset number_of_shards. 
Assuming 1.") num_shards = 1 + + subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment, + num_shards) + if num_shards > 1 and self._reduction == losses.Reduction.MEAN: - scale = 1.0 / num_shards + if self._group_assignment: + scale = 1.0 / subgroup_size + else: + scale = 1.0 / num_shards loss *= scale + return self._opt.compute_gradients(loss, var_list=var_list, **kwargs) def apply_gradients(self, grads_and_vars, global_step=None, name=None): @@ -110,7 +153,8 @@ class CrossShardOptimizer(optimizer.Optimizer): if grad is None: summed_grads_and_vars.append((grad, var)) else: - summed_grads_and_vars.append((tpu_ops.cross_replica_sum(grad), var)) + summed_grads_and_vars.append((tpu_ops.cross_replica_sum( + grad, self._group_assignment), var)) return self._opt.apply_gradients(summed_grads_and_vars, global_step, name) def get_slot(self, *args, **kwargs): diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py index 409aba817c1ec37003eb98f000f6cf8918234c5d..a2444934bc21d58ed57d15494b3548a31ce3a2df 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -45,14 +46,14 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset): self._input_dataset = input_dataset self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - # pylint: disable=protected-access if padded_shapes is None: self._padded_shapes = nest.map_structure( - dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes) + 
convert.partial_shape_to_tensor, input_dataset.output_shapes) else: self._padded_shapes = nest.map_structure_up_to( - input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor, + input_dataset.output_shapes, convert.partial_shape_to_tensor, padded_shapes) + # pylint: disable=protected-access padding_values = ( padding_values if padding_values is not None else dataset_ops._default_padding(input_dataset)) diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index 9720fd6e8657de18cf8d7565f834568ae52fdbda..19cb8983b6836266ebfac70c54657a96324e8435 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -53,12 +53,12 @@ cc_library( ":grpc_verbs_service_impl", ":rdma_mgr", ":verbs_service_proto_cc", + "//tensorflow:grpc++", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime/rpc:async_service_interface", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++_unsecure", ], alwayslink = 1, ) @@ -69,7 +69,7 @@ cc_library( hdrs = ["grpc_verbs_service_impl.h"], deps = [ ":verbs_service_proto_cc", - "@grpc//:grpc++_unsecure", + "//tensorflow:grpc++", ], ) diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc index 742f946c9536973eb8a6a11afda1b32ae4a7726b..af29abd91feda22824e57c19c13a3f48fb1d61b7 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc @@ -15,9 +15,9 @@ limitations under the License. 
#ifdef TENSORFLOW_USE_VERBS -#include "grpc++/alarm.h" -#include "grpc++/grpc++.h" -#include "grpc++/server_builder.h" +#include "grpcpp/alarm.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/server_builder.h" #include "tensorflow/contrib/verbs/grpc_verbs_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc index 991f9a9d8bdf883b1b68bfa1fb6af7bf51b7e66a..4da7b59c69c88a4d04be37543aae7f03decd2c52 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h" -#include "grpc++/impl/codegen/async_stream.h" -#include "grpc++/impl/codegen/async_unary_call.h" -#include "grpc++/impl/codegen/channel_interface.h" -#include "grpc++/impl/codegen/client_unary_call.h" -#include "grpc++/impl/codegen/method_handler_impl.h" -#include "grpc++/impl/codegen/rpc_service_method.h" -#include "grpc++/impl/codegen/service_type.h" -#include "grpc++/impl/codegen/sync_stream.h" +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/channel_interface.h" +#include "grpcpp/impl/codegen/client_unary_call.h" +#include "grpcpp/impl/codegen/method_handler_impl.h" +#include "grpcpp/impl/codegen/rpc_service_method.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/sync_stream.h" namespace tensorflow { diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h index 1f0f10517e98a32ae882c027330091928f1a6ee2..abe5e08b07cd71b7ca28321e6eb2cf0eec5d1b0f 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h @@ -16,14 +16,14 @@ limitations under the License. 
#ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ #define TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ -#include "grpc++/impl/codegen/async_stream.h" -#include "grpc++/impl/codegen/async_unary_call.h" -#include "grpc++/impl/codegen/proto_utils.h" -#include "grpc++/impl/codegen/rpc_method.h" -#include "grpc++/impl/codegen/service_type.h" -#include "grpc++/impl/codegen/status.h" -#include "grpc++/impl/codegen/stub_options.h" -#include "grpc++/impl/codegen/sync_stream.h" +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/proto_utils.h" +#include "grpcpp/impl/codegen/rpc_method.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/status.h" +#include "grpcpp/impl/codegen/stub_options.h" +#include "grpcpp/impl/codegen/sync_stream.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 5de59eaef7cd9924696bab3586521e7ba04f972b..b6b48a077cdafe12aeb1e4e0988493692c82eace 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -232,7 +232,6 @@ tf_proto_library( name = "protos_all", srcs = [], cc_api_version = 2, - dart_api_version = 2, default_header = True, j2objc_api_version = 1, java_api_version = 2, @@ -879,6 +878,7 @@ cc_library( hdrs = [ "util/stats_calculator.h", ], + copts = tf_copts(), ) cc_library( @@ -2237,7 +2237,6 @@ tf_proto_library( name = "error_codes_proto", srcs = ERROR_CODES_PROTO_SRCS, cc_api_version = 2, - dart_api_version = 2, default_header = True, j2objc_api_version = 1, java_api_version = 2, @@ -2260,7 +2259,6 @@ tf_proto_library( name = "protos_all_proto", srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS, cc_api_version = 2, - dart_api_version = 2, default_header = True, j2objc_api_version = 1, java_api_version = 2, @@ -2636,6 +2634,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/dma_helper.h", "common_runtime/eigen_thread_pool.h", 
"common_runtime/executor.h", + "common_runtime/executor_factory.h", "common_runtime/graph_optimizer.h", "common_runtime/local_device.h", "common_runtime/lower_if_op.h", @@ -2685,6 +2684,7 @@ tf_cuda_library( "common_runtime/device_resolver_local.cc", "common_runtime/device_set.cc", "common_runtime/executor.cc", + "common_runtime/executor_factory.cc", "common_runtime/function.cc", "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 19d643880966f7607405539a5ad43d8e03dc13fb..06b797e32edc046bab498f8d775040d57ef62ce9 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -4,6 +4,7 @@ # The following targets can be used to access ApiDefs: # :base_api_def # :python_api_def +# :java_api_def package( default_visibility = ["//visibility:private"], @@ -29,6 +30,12 @@ filegroup( visibility = ["//tensorflow:internal"], ) +filegroup( + name = "java_api_def", + srcs = glob(["java_api/*"]), + visibility = ["//tensorflow:internal"], +) + cc_library( name = "excluded_ops_lib", srcs = ["excluded_ops.cc"], diff --git a/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0c5b1eb45af6812bdd35e2fef43ac8c02a5b9388 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt @@ -0,0 +1,18 @@ +op { + graph_op_name: "BatchDatasetV2" + visibility: HIDDEN + in_arg { + name: "batch_size" + description: <