diff --git a/configure.py b/configure.py index 14fca1f73236eb01ec4bc24499544453fb0807f8..61fa9feaded7e98c531b620891443ba77f182e9b 100644 --- a/configure.py +++ b/configure.py @@ -55,6 +55,12 @@ NCCL_LIB_PATHS = [ 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', '' ] +# List of files to be configured for using Bazel on Apple platforms. +APPLE_BAZEL_FILES = [ + 'tensorflow/lite/experimental/objc/BUILD', + 'tensorflow/lite/experimental/swift/BUILD' +] + if platform.machine() == 'ppc64le': _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/' else: @@ -1534,6 +1540,23 @@ def config_info_line(name, help_text): print('\t--config=%-12s\t# %s' % (name, help_text)) +def configure_apple_bazel_rules(): + """Configures Bazel rules for building on Apple platforms. + + Enables analyzing and building Apple Bazel rules on Apple platforms. This + function will only be executed if `is_macos()` is true. + """ + if not is_macos(): + return + for filepath in APPLE_BAZEL_FILES: + print( + 'Configuring %s file to analyze and build Bazel rules on Apple platforms.' + % filepath) + existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple') + renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath) + os.rename(existing_filepath, renamed_filepath) + + def main(): global _TF_WORKSPACE_ROOT global _TF_BAZELRC @@ -1574,6 +1597,8 @@ def main(): if is_macos(): environ_cp['TF_NEED_TENSORRT'] = '0' + else: + environ_cp['TF_CONFIGURE_APPLE_BAZEL_RULES'] = '0' # The numpy package on ppc64le uses OpenBLAS which has multi-threading # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at @@ -1676,6 +1701,14 @@ def main(): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) + if get_var( + environ_cp, 'TF_CONFIGURE_APPLE_BAZEL_RULES', + 'Configure Bazel rules for Apple platforms', False, + ('Would you like to configure Bazel rules for building on Apple platforms?' + ), 'Configuring Bazel rules for Apple platforms.', + 'Not configuring Bazel rules for Apple platforms.'): + configure_apple_bazel_rules() + print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. See .bazelrc for more ' 'details.') diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 0b63ee4056c57a58aa22560bea63dd3fac623602..f53982f1efc9885cc12dcc672ad819c762aca378 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -462,8 +462,7 @@ tf_cc_shared_object( "//tensorflow:darwin": [], "//tensorflow:windows": [], "//conditions:default": [ - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow:tf_framework_version_script.lds)", + "-Wl,--version-script,$(location //tensorflow:tf_framework_version_script.lds)", ], }), linkstatic = 1, @@ -497,15 +496,13 @@ tf_cc_shared_object( name = "libtensorflow.so", linkopts = select({ "//tensorflow:darwin": [ - "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "$(location //tensorflow/c:exported_symbols.lds)", + "-Wl,-exported_symbols_list,$(location //tensorflow/c:exported_symbols.lds)", "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], "//conditions:default": [ "-z defs", - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow/c:version_script.lds)", + "-Wl,--version-script,$(location //tensorflow/c:version_script.lds)", ], }), visibility = ["//visibility:public"], @@ -523,14 +520,12 @@ tf_cc_shared_object( name = "libtensorflow_cc.so", linkopts = select({ "//tensorflow:darwin": [ - "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "$(location //tensorflow:tf_exported_symbols.lds)", + "-Wl,-exported_symbols_list,$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], "//conditions:default": [ "-z defs", - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow:tf_version_script.lds)", + "-Wl,--version-script,$(location //tensorflow:tf_version_script.lds)", ], }), visibility = ["//visibility:public"], diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index a6eb4755f32d2504ae1aab747f110d68d72a0d5f..ddcacfcbe2d4d8b089f10f1a771384dc8c4fd199 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -26,14 +26,28 @@ import sys as _sys # API IMPORTS PLACEHOLDER +# Make sure directory containing top level submodules is in +# the __path__ so that "from tensorflow.foo import bar" works. +# We're using bitwise, but there's nothing special about that. +_API_MODULE = bitwise # pylint: disable=undefined-variable +_current_module = _sys.modules[__name__] +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) +if not hasattr(_current_module, '__path__'): + __path__ = [_tf_api_dir] +elif _tf_api_dir not in __path__: + __path__.append(_tf_api_dir) + # pylint: disable=g-bad-import-order from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorboard.summary._tf.summary'), + error_msg="Limited tf.summary API due to missing TensorBoard installation") _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=( 'tensorflow_estimator.python.estimator.api._v2.estimator')) -_current_module = _sys.modules[__name__] if not hasattr(_current_module, 'estimator'): _component_api_helper.package_hook( parent_package_str=__name__, @@ -42,14 +56,6 @@ if not hasattr(_current_module, 'estimator'): _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow.python.keras.api._v2.keras')) -# Make sure directory containing top level submodules is in -# the __path__ so that "from tensorflow.foo import bar" works. -# We're using bitwise, but there's nothing special about that. -_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable -if not hasattr(_current_module, '__path__'): - __path__ = [_tf_api_dir] -elif _tf_api_dir not in __path__: - __path__.append(_tf_api_dir) # Enable TF2 behaviors from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index eeca8f0d566a6401cb64e4fe3f0ee3c5aeb4ece2..5eb25a81b7f765f551bc4f1b7ba99b35dbc6b7bb 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -70,7 +70,7 @@ _API_MODULE = app # pylint: disable=undefined-variable # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. -_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) # pylint: disable=undefined-variable +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) if not hasattr(_current_module, '__path__'): __path__ = [_tf_api_dir] elif _tf_api_dir not in __path__: diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index ef22b67fe95364a0513ebd7a59d116a2d78cc2e9..245d7ba2b186895532953aa61ebfc3fc6bf635a7 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -641,7 +641,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, dimvec.size(), base, size, DeleteArray, base); } -Status MessageToBuffer(const tensorflow::protobuf::Message& in, +Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, TF_Buffer* out) { if (out->data != nullptr) { return InvalidArgument("Passing non-empty TF_Buffer is invalid."); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 88b8b49b016efec4eb21271d275e50c786e1e602..051de3a7dc0f8c630b6c81d2cfa960e5279c93c0 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1314,6 +1314,28 @@ TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( int noutputs, const TF_Output* outputs, const char* const* output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* status); +// Similar to TF_GraphToFunction but allows specifying control outputs of the +// function. +// +// The arguments of TF_GraphToFunction have the same meaning, but the new +// arguments are as follows: +// +// ncontrol_outputs: Number of control outputs of the function. +// control_outputs: vector of TF_Operation objects to be marked as control +// outputs of the function. Operations marked as control outputs are +// guaranteed to execute. +// control_output_names: Optional. If not nullptr, vector of strings, one +// per control output, with their names to be added to the function's +// OpDef. +TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunctionWithControlOutputs( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + int ncontrol_outputs, const TF_Operation* const* control_outputs, + const char* const* control_output_names, const TF_FunctionOptions* opts, + const char* description, TF_Status* status); + // Returns the name of the graph function. // The return value points to memory that is only usable until the next // mutation to *func. diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index a8325ce494c4f57fcd7e64b2d233ee4e6666bc4e..7ff4084decc686b067226ecaecf2af29d51d42f2 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -9064,11 +9064,6 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++); auto* desc = TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()); - for (auto* input : op->operation.Inputs()) { - auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, input, status); - if (!status->status.ok()) return nullptr; - TF_AddInput(desc, symbolic_input); - } VLOG(1) << "Adding attrs."; tensorflow::AttrValueMap attrs; @@ -9077,6 +9072,34 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, desc->node_builder.Attr(attr.first, attr.second); } + VLOG(1) << "Adding inputs."; + const auto& inputs = op->operation.Inputs(); + size_t inputIndex = 0; + const tensorflow::OpDef& op_def = desc->node_builder.op_def(); + for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) { + // TODO(bgogul): Add support for number attributes. + DCHECK(input_arg.number_attr().empty()) + << "Number attributes is not implemented yet."; + if (input_arg.type_list_attr().empty()) { + auto symbolic_input = + getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); + if (!status->status.ok()) return nullptr; + TF_AddInput(desc, symbolic_input); + continue; + } + const std::string& type_list_attr = input_arg.type_list_attr(); + const auto& attr_value = attrs[type_list_attr]; + DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList) + << "Type list attribute should be a list!"; + std::vector list_inputs(attr_value.list().type_size()); + for (TF_Output& list_input : list_inputs) { + list_input = + getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); + if (!status->status.ok()) return nullptr; + } + TF_AddInputList(desc, list_inputs.data(), list_inputs.size()); + } + auto* graph_op = TF_FinishOperation(desc, status); if (!status->status.ok()) return nullptr; diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 354ee5f49f373edbc10e7706aa8776f3cc2a17cd..c54021a7517ebbdd00405cbfa9cee8f3f6616cca 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -446,5 +446,29 @@ TEST_F(AddEagerOpToGraphTest, ListAttributesArePreserved) { TFE_DeleteOp(squeeze); } +TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) { + TFE_TensorHandle* scalar = TestScalarTensorHandle(); + TFE_Op* identityn = TFE_NewOp(eager_ctx_, "IdentityN", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + constexpr size_t kNumInputs = 3; + for (size_t i = 0; i < kNumInputs; ++i) { + TFE_OpAddInput(identityn, scalar, status_); + } + TF_DataType types[kNumInputs] = {TF_FLOAT, TF_FLOAT, TF_FLOAT}; + TFE_OpSetAttrTypeList(identityn, "T", types, kNumInputs); + AddEagerOpToGraphAndCheck( + identityn, [this, kNumInputs](TF_Operation* graph_op) { + EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs); + EXPECT_EQ(TF_OperationInputListLength(graph_op, "input", status_), + kNumInputs); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + EXPECT_EQ(TF_OperationOutputListLength(graph_op, "output", status_), + kNumInputs); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + }); + TFE_DeleteTensorHandle(scalar); + TFE_DeleteOp(identityn); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 45d6c33a1e7053451d1dbadff480cf300ea4abbb..03d65ecefd4a9ba5a23a94ed902dfba6dd4fbda9 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -272,10 +272,17 @@ Status FillFunctionBody( } } if (!node_attr_def) { +#ifdef TENSORFLOW_LITE_PROTOS + return errors::Unimplemented( + "Placeholder value is not supported for attributes not in OpDef. " + "Attribute: ", + node_attr_name); +#else return errors::Unimplemented( "Placeholder value is not supported for attributes not in OpDef. " "Attribute: ", node_attr_name, ", OpDef: ", node->op_def().DebugString()); +#endif } OpDef::AttrDef* attr_def = fdef->mutable_signature()->add_attr(); attr_def->set_name(func_attr_name); @@ -295,6 +302,8 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, const std::vector& inputs, const std::vector& outputs, const std::vector& output_names, + const std::vector& control_outputs, + const std::vector& control_output_names, const char* description, FunctionDef* fdef) { if (!output_names.empty()) { DCHECK_EQ(output_names.size(), outputs.size()); @@ -418,6 +427,29 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, fdef->mutable_signature()->set_name(fn_name); } + if (!control_output_names.empty() && + (control_outputs.size() != control_output_names.size())) { + return InvalidArgument( + "Expected number of control outputs (", control_outputs.size(), + ") and the number of control output names (", + control_output_names.size(), ") to match but they do not."); + } + std::unordered_set control_output_names_set; + for (int i = 0; i < control_outputs.size(); ++i) { + string signature_name; + if (!control_output_names.empty()) { + signature_name = control_output_names[i]; + } else { + signature_name = control_outputs[i]->name(); + } + if (!control_output_names_set.insert(signature_name).second) { + return errors::InvalidArgument("Repeated control output name: ", + signature_name); + } + fdef->mutable_signature()->add_control_output(signature_name); + (*fdef->mutable_control_ret())[signature_name] = control_outputs[i]->name(); + } + return Status::OK(); } @@ -525,14 +557,14 @@ Status ComputeBodyNodes( using tensorflow::Node; using tensorflow::string; -TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, - unsigned char append_hash_to_fn_name, - int num_opers, const TF_Operation* const* opers, - int ninputs, const TF_Output* inputs, - int noutputs, const TF_Output* outputs, - const char* const* output_names, - const TF_FunctionOptions* opts, - const char* description, TF_Status* status) { +TF_Function* TF_GraphToFunctionWithControlOutputs( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + int ncontrol_outputs, const TF_Operation* const* control_outputs, + const char* const* control_output_names, const TF_FunctionOptions* opts, + const char* description, TF_Status* status) { tensorflow::mutex_lock l(*const_cast(&fn_body->mu)); // Process inputs. @@ -557,19 +589,34 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, } } + // Process control output names. + std::vector control_output_names_vec; + if (control_output_names) { + control_output_names_vec.reserve(ncontrol_outputs); + for (int i = 0; i < ncontrol_outputs; ++i) { + control_output_names_vec.push_back(string(output_names[i])); + } + } + // Compute body nodes. std::vector body_nodes; status->status = tensorflow::ComputeBodyNodes( fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes); if (!status->status.ok()) return nullptr; + // Compute body nodes. + std::vector control_output_nodes; + for (int i = 0; i < ncontrol_outputs; ++i) { + control_output_nodes.push_back(&control_outputs[i]->node); + } + // Do the actual function creation. TF_Function* tf_function = new TF_Function(); DCHECK(append_hash_to_fn_name <= 1); status->status = tensorflow::GraphToFunctionDef( fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes, - input_tensors, output_tensors, output_names_vec, description, - &tf_function->fdef); + input_tensors, output_tensors, output_names_vec, control_output_nodes, + control_output_names_vec, description, &tf_function->fdef); if (!status->status.ok()) { TF_DeleteFunction(tf_function); return nullptr; @@ -577,6 +624,20 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, return tf_function; } +TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, + int num_opers, const TF_Operation* const* opers, + int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, + const char* const* output_names, + const TF_FunctionOptions* opts, + const char* description, TF_Status* status) { + return TF_GraphToFunctionWithControlOutputs( + fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs, + inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts, + description, status); +} + const char* TF_FunctionName(TF_Function* func) { return func->fdef.signature().name().c_str(); } diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 73283d775639b297857b2a50007dc7c28b1f39a3..d520b6b76849e562def6abd8be0510d3b4797e8c 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -204,7 +204,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); -Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); +Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, + TF_Buffer* out); // Set the shapes and types of the output's handle. // diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index af13f487af91594fedd4d5f77592682a6f98c34f..45701c7fcf02d5e6ec464ae10d4d20f20ba1d9f0 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -762,11 +762,13 @@ unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreMetadata(true); + ctx->context.SetShouldStoreGraphs(true); + ctx->context.SetShouldStoreStepStats(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreMetadata(false); + ctx->context.SetShouldStoreGraphs(false); + ctx->context.SetShouldStoreStepStats(false); } } // extern "C" diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 06bbb4ac41256524b566657105cc5d5858234405..af7f1bbf8aa5636d78c222f5ba95624054273c47 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -68,3 +68,11 @@ void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) { // terminating the main thread. tensorflow::StartProfilerServer(&context->profiler_context, port).release(); } + +void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { + ctx->context.SetShouldStoreGraphs(true); +} + +void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { + ctx->context.SetShouldStoreGraphs(false); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 51a5fa0d816a179bb52940f6d8aead867ec9a267..eb57077e6834354005fbf7913cf5f51db3087b07 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -67,6 +67,14 @@ TF_CAPI_EXPORT extern void TFE_DeleteProfilerContext( TF_CAPI_EXPORT extern void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port); +// Enables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx); + +// Disables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index a09becc49b10d2c58f98fbcc11df5190f794c1d4..4c4d587fce04d101b3cc8faebcc3ba04f2f1d0cf 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -150,6 +150,7 @@ cc_library_with_android_deps( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", ], ) @@ -586,6 +587,25 @@ tf_gen_op_wrappers_cc( pkg = "//tensorflow/core", ) +tf_gen_op_wrappers_cc( + name = "tpu_ops", + include_internal_ops = 1, + op_lib_names = [ + "tpu_configuration_ops", + "tpu_cross_replica_ops", + "tpu_embedding_ops", + "tpu_functional_ops", + "tpu_heartbeat_ops", + "tpu_host_compute_ops", + "tpu_infeed_ops", + "tpu_outfeed_ops", + "tpu_ordinal_selector_ops", + "tpu_replication_ops", + ], + pkg = "//tensorflow/core", + visibility = ["//tensorflow:internal"], +) + cc_library_with_android_deps( name = "cc_op_gen_main", srcs = [ diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 52345a376cc29ee47ccb9888c9bb26292468b5a9..dedd55f16afb879ea966dc89d14d88ee15d9e83e 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -81,6 +81,7 @@ cc_library( ] + if_not_mobile([ "//tensorflow/core:core_cpu", "//tensorflow/core:lib", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", ]) + if_android([ diff --git a/tensorflow/compat_template.__init__.py b/tensorflow/compat_template.__init__.py index 05fd9cd981f70b9f54b65a59a2e92c5405a80f08..2cf68c9cd8396987899b4f34f21b994b4722ead4 100644 --- a/tensorflow/compat_template.__init__.py +++ b/tensorflow/compat_template.__init__.py @@ -22,11 +22,16 @@ import os as _os import sys as _sys # pylint: disable=g-bad-import-order -from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import # API IMPORTS PLACEHOLDER from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorboard.summary._tf.summary'), + error_msg=( + "Limited tf.compat.v2.summary API due to missing TensorBoard " + "installation")) _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=( diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 4e3229bc709d25195392afc84382a61703782255..121de401cefb2b56b984944dde769f226590dc67 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -208,6 +208,7 @@ cc_library( "//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:iterator_ops", + "//tensorflow/core/kernels/data:optional_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 1f8ec09e19c01d0a8b2a3761135ed53dfb2ad3b0..261519de3478c8b3e30d206a15944b5a686598e2 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -307,22 +307,6 @@ REGISTER_OP("XlaHostCompute") .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); -REGISTER_OP("_XlaSendFromHost") - .Input("inputs: Tinputs") - .Input("dynamic_key: string") - .Attr("Tinputs: list(type) >= 0") - .Attr("key: string") - .Attr("device_ordinal: int") - .SetShapeFn(::tensorflow::shape_inference::UnknownShape); - -REGISTER_OP("_XlaRecvAtHost") - .Input("dynamic_key: string") - .Output("outputs: Toutputs") - .Attr("Toutputs: list(type) >= 0") - .Attr("key: string") - .Attr("device_ordinal: int") - .SetShapeFn(::tensorflow::shape_inference::UnknownShape); - REGISTER_OP("InputTest") .Output("o: float") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 109684be72a2d67d04ac9efda0b17650f6905752..f0c9d573451952a398dce190e102a33270a4d739 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -200,7 +200,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, auto serialized = absl::make_unique(size); TF_RET_CHECK(SerializeToBufferDeterministic(gdef, serialized.get(), size)); uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size)); - LOG(INFO) << "Subgraph fingerprint:" << fingerprint; + VLOG(1) << "Subgraph fingerprint:" << fingerprint; call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 92229842bdbced6431fd5b3e158f275a41819728..eceb47f167f46784dc935a1d8b6fb4e5fe469367 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -102,7 +102,8 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( } std::unique_ptr alloc = - absl::make_unique(); + absl::make_unique( + backend->stream_executors()[device_ordinal]); XlaDeviceAllocator* alloc_ptr = alloc.get(); state.allocators_[{backend, device_ordinal}] = std::move(alloc); return alloc_ptr; diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 28681bb8b03dbf97e8145972f9a04b5855fafdae..05b9c511866d3ca48ec3519bee8a4dbf6086f6ac 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -29,7 +29,10 @@ limitations under the License. namespace tensorflow { // The allocator used for Tensors assigned to the XLA device. -XlaDeviceAllocator::XlaDeviceAllocator() {} +XlaDeviceAllocator::XlaDeviceAllocator( + stream_executor::StreamExecutor* stream_executor) + : stream_executor_(stream_executor) {} + XlaDeviceAllocator::~XlaDeviceAllocator() = default; string XlaDeviceAllocator::Name() { return "xla"; } @@ -48,7 +51,21 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { delete XlaTensor::FromOpaquePointer(ptr); } -void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } +absl::optional XlaDeviceAllocator::GetStats() { + absl::optional se_stats = + stream_executor_->GetAllocatorStats(); + if (!se_stats) { + return absl::nullopt; + } + + tensorflow::AllocatorStats tf_stats; + tf_stats.num_allocs = se_stats->num_allocs; + tf_stats.bytes_in_use = se_stats->bytes_in_use; + tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use; + tf_stats.largest_alloc_size = se_stats->largest_alloc_size; + tf_stats.bytes_limit = se_stats->bytes_limit; + return tf_stats; +} XlaDeviceContext::XlaDeviceContext( std::shared_ptr compute_stream, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index e45db989fac720df6c3458c93a6b8dbb0919f930..1ce64ad323b4827adc2f4d48841315fbde43e532 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -34,14 +34,18 @@ namespace tensorflow { // empty, XlaTensor. class XlaDeviceAllocator : public Allocator { public: - XlaDeviceAllocator(); + XlaDeviceAllocator(se::StreamExecutor* stream_executor); ~XlaDeviceAllocator() override; string Name() override; void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; - void GetStats(AllocatorStats* stats) override; + absl::optional GetStats() override; + + private: + // The stream executor of the device. + se::StreamExecutor* stream_executor_; }; // Helper class for managing data transfers between host and XLA devices. diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index f201f62a78ce9d9599b2397a5f108e335469445a..09e04d22def9c39f45c2737c1d4a5e7787e3fdc0 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/kernels/control_flow_ops.h" #include "tensorflow/core/kernels/data/generator_dataset_op.h" #include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/kernels/data/optional_ops.h" #include "tensorflow/core/kernels/data/prefetch_dataset_op.h" #include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/function_ops.h" @@ -253,6 +254,15 @@ class XlaAssignVariableOp : public OpKernel { .Device(DEVICE) \ .HostMemory("string_handle"), \ data::IteratorFromStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE), \ + data::OptionalNoneOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE), \ + data::OptionalFromValueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("OptionalHasValue").Device(DEVICE).HostMemory("has_value"), \ + data::OptionalHasValueOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE), \ + data::OptionalGetValueOp); \ REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ .Device(DEVICE) \ .HostMemory("output") \ diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 9b6ca4092c3177ac26503add13bce25d2c0bb820..7c1e0daf0b7b418530367cb80fbd18b93e8e5f5e 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -250,6 +250,29 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "self_adjoint_eig_op_test", + size = "medium", + srcs = ["self_adjoint_eig_op_test.py"], + # TODO(kuny): remove it after b/124377352 is fixed. + disabled_backends = [ + "cpu", + "gpu", + "cpu_ondemand", + ], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:map_fn", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", + ], +) + tf_xla_py_test( name = "matrix_triangular_solve_op_test", size = "small", diff --git a/tensorflow/compiler/tests/plugin.bzl b/tensorflow/compiler/tests/plugin.bzl index fbc8781a3e59faecf985cde5114bf56a041c4be0..46a854d1459b7ea9d9fe3cf7689faee557c2cf84 100644 --- a/tensorflow/compiler/tests/plugin.bzl +++ b/tensorflow/compiler/tests/plugin.bzl @@ -18,13 +18,12 @@ # git update-index --assume-unchanged tensorflow/compiler/tests/plugin.bzl plugins = { - #"example": { - # "device":"XLA_MY_DEVICE", - # "types":"DT_FLOAT,DT_HALF,DT_INT32", - # "tags":[], - # "args":["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"], - # "data":["//tensorflow/compiler/plugin/example:disabled_manifest.txt"], - # "deps":[], - #}, + #"example": { + # "device":"XLA_MY_DEVICE", + # "types":"DT_FLOAT,DT_HALF,DT_INT32", + # "tags":[], + # "args":["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"], + # "data":["//tensorflow/compiler/plugin/example:disabled_manifest.txt"], + # "deps":[], + #}, } - diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 693f8513bc54e30060a2e963abd504768535a50a..a9a87b8fb3104f8b9870c41e2aa28b0c48c12921 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -134,6 +134,12 @@ class ScatterNdTest(xla_test.XLATestCase): expected = np.array([0, 11, 0, 10, 9, 0, 0, 12], dtype=np.int32) self.assertAllEqual(expected, self._runScatterNd(indices, updates, [8])) + def testRepeatedIndices(self): + indices = np.array([[0], [1], [0], [1]], dtype=np.int32) + updates = np.array([9, 10, 11, 12], dtype=np.float32) + expected = np.array([20, 22], dtype=np.int32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2])) + def testSimple2(self): indices = np.array([[1, 0], [1, 1]], dtype=np.int32) updates = np.array([11., 12.], dtype=np.float32) diff --git a/tensorflow/compiler/tests/self_adjoint_eig_op_test.py b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb5c82b22ea1d7400b54045edee0ca0782ce979 --- /dev/null +++ b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py @@ -0,0 +1,62 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.self_adjoint_eig.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.platform import test + + +class SelfAdjointEigOpTest(xla_test.XLATestCase, parameterized.TestCase): + + def _test(self, dtype, shape): + np.random.seed(1) + x_np = np.random.uniform( + low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) + x_np = x_np + np.swapaxes(x_np, -1, -2) + n = shape[-1] + + e_np, _ = np.linalg.eigh(x_np) + with self.cached_session() as sess: + x_tf = array_ops.placeholder(dtype) + with self.test_scope(): + e, v = linalg_ops.self_adjoint_eig(x_tf) + e_val, v_val = sess.run([e, v], feed_dict={x_tf: x_np}) + + v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n) + self.assertAlmostEqual(np.mean(v_diff**2), 0.0, delta=1e-6) + self.assertAlmostEqual(np.mean((e_val - e_np)**2), 0.0, delta=1e-6) + + SIZES = [1, 2, 5, 10, 32] + DTYPES = [np.float32] + PARAMS = itertools.product(SIZES, DTYPES) + + @parameterized.parameters(*PARAMS) + def testSelfAdjointEig(self, n, dtype): + for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10): + self._test(dtype, batch_dims + (n, n)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 47e0f384a4f1e46ccc35584aaff3a0aceff8a985..a380715301b08ce2186c97b678b7235b9121d178 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -102,7 +102,7 @@ class ListOpsTest(xla_test.XLATestCase): _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) with self.assertRaisesRegexp(errors.InvalidArgumentError, "Set the max number of elements"): - self.assertEqual(sess.run(e), 1.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15))) def testEmptyTensorListMax(self): with self.cached_session() as sess, self.test_scope(): @@ -136,6 +136,17 @@ class ListOpsTest(xla_test.XLATestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(t, [3.0, 2.0]) + def testSetDoesNotUpdatePushIndex(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_shape=[], element_dtype=dtypes.float32, max_num_elements=2) + # SetItem should not change the push index. + l = list_ops.tensor_list_set_item(l, 1, 3.) + l = list_ops.tensor_list_push_back(l, 5.) + l = list_ops.tensor_list_push_back(l, 7.) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [5., 7.]) + def testGetSetReserved(self): with self.cached_session(), self.test_scope(): l = list_ops.tensor_list_reserve( @@ -146,6 +157,25 @@ class ListOpsTest(xla_test.XLATestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(t, [3.0, 0.0]) + def testSetStackReservedUnknownElementShape(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=None, num_elements=2) + l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0]) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]]) + + def testPushInEmptyListWithUnknownElementShape(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, element_shape=None, max_num_elements=2) + l = list_ops.tensor_list_push_back(l, [3.0, 4.0]) + # Pushing an element with a different shape should raise an error. + with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"): + l = list_ops.tensor_list_push_back(l, 5.) + self.evaluate( + list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)) + def testGetSetReservedNonScalar(self): with self.cached_session() as sess, self.test_scope(): l = list_ops.tensor_list_reserve( diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 7466aea4c9b06c3e8fa7dfe5937288b5425f3e8b..63cad6a159c3a9b0da9e3bb86ff250dd29e45729 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -171,13 +171,11 @@ tf_cuda_library( name = "trt_resources", srcs = [ "utils/trt_int8_calibrator.cc", - "utils/trt_resource_manager.cc", "utils/trt_resources.cc", ], hdrs = [ "utils/trt_int8_calibrator.h", "utils/trt_lru_cache.h", - "utils/trt_resource_manager.h", "utils/trt_resources.h", ], deps = [ @@ -266,7 +264,6 @@ tf_cuda_library( "//tensorflow/core:framework_lite", "//tensorflow/core:gpu_runtime", "//tensorflow/core:graph", - "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:devices", @@ -433,7 +430,7 @@ cc_library( copts = tf_copts(), deps = [ "//tensorflow/core:framework", - "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", ], ) @@ -442,7 +439,7 @@ cc_library( srcs = ["utils/test_utils.cc"], hdrs = ["utils/test_utils.h"], deps = [ - "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", "@com_googlesource_code_re2//:re2", ], ) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index d6080c02d435fc149f679ebe9c9bacc8d0a0c144..3f4b3732b0ddb7a36a985ad4b7950594fef8eb41 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/compiler/tf2tensorrt/segment/segment.h" #include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" @@ -90,7 +89,7 @@ TrtCandidateSelector::TrtCandidateSelector( Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(laigd): move this set to TrtNodeValidator where it should belong. // LINT.IfChange - static const std::set candidate_ops = { + static const auto* candidate_ops = new std::set{ "Abs", "Add", "AvgPool", @@ -106,6 +105,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { "ExpandDims", "FusedBatchNorm", "FusedBatchNormV2", + "GatherV2", "Identity", "LeakyRelu", "Log", @@ -128,6 +128,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { "Rsqrt", "Rsqrt", "Sigmoid", + "Slice", "Snapshot", "Softmax", "Sqrt", @@ -141,9 +142,9 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { "Transpose", }; bool is_supported_op_type = - (candidate_ops.count(node->type_string()) || + (candidate_ops->count(node->type_string()) || PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); - static const std::set quantize_ops = { + static const auto* quantize_ops = new std::set{ "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", "FakeQuantWithMinMaxVars", @@ -153,7 +154,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // these ops to the relevant tensors. This happens regardless of the value of // use_calibration. if (precision_mode_ == TrtPrecisionMode::INT8 && - quantize_ops.count(node->type_string())) { + quantize_ops->count(node->type_string())) { is_supported_op_type = true; } // LINT.ThenChange(//tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc) @@ -190,55 +191,6 @@ tensorflow::Status BuildNodeMap( } // namespace -// Function to get calibration from ResourceMgr and put them into nodedef. -tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph, - bool is_dyn_op) { - LOG(INFO) << "Starting Calib Conversion"; - *infer_graph = graph_def; - auto trt_rm = TRTResourceManager::instance(); - auto calib_rm = trt_rm->getManager("TRTCalibration"); - int num_nodes = infer_graph->node_size(); - if (!is_dyn_op) { - LOG(WARNING) << "Construction of static int8 engine is not implemented " - "yet!. Dynamic engine will be constructed"; - } - for (int i = 0; i < num_nodes; ++i) { - auto n = infer_graph->mutable_node(i); - if (n->op() == "TRTEngineOp") { - VLOG(1) << "Processing " << n->name(); - const string& container_name = n->attr().at("segment_funcdef_name").s(); - TRTCalibrationResource* cres = nullptr; - auto status = calib_rm->Lookup(container_name, "Calibrator", &cres); - if (!status.ok()) { - LOG(ERROR) << "Could not get Calibration information. Did you run with " - "calibration data?"; - return tensorflow::errors::FailedPrecondition( - "Need to run graph with calibration data first!"); - } - tensorflow::core::ScopedUnref calib_sc(cres); - if (cres->calibrator_) { - cres->calibrator_->waitAndSetDone(); - cres->thr_->join(); - const auto& calibration_table = - cres->calibrator_->getCalibrationTableAsString(); - if (calibration_table.empty()) { - LOG(ERROR) << "Calibration table is empty"; - return tensorflow::errors::Unknown( - "Calibration table is missing. This shouldn't have happened!"); - } - n->mutable_attr()->at("calibration_data").set_s(calibration_table); - } else { - LOG(ERROR) << "Can't get TRTCalibrator from resource manager!"; - return tensorflow::errors::Unknown( - "Can't get TRTCalibrator from resource manager!"); - } - TF_RETURN_IF_ERROR(calib_rm->Cleanup(container_name)); - } - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index 95cf0227dcf84396b9de52194ae3a750f4acca66..80f68d36a3ab894e97586687ee9ab93dddc73c50 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -85,12 +85,6 @@ struct ConversionParams { std::vector cached_engine_batches; // list of cached engines }; -// This method extracts calibration information from the resource managers -// and puts them in to engine nodedefs. -tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def, - bool is_dyn_op); - // - max_batch_size: maximum batch size which can be used for inference for // optimization targets inference run with max batch size. // - max_workspace_size_bytes: The upper bound of memory allowance for engine diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 0d5b9851f79e97279ec0680986efe13e56dbd7c5..de9c1b69f4020353064f25cdf5c652ad44a1cae3 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" @@ -379,6 +378,32 @@ tensorflow::Status CreateBroadcastableScalarConstant( return Status::OK(); } +// Convert an axis from TF format to TRT format while validating. TF format +// includes the batch dimension, while TRT does not. TF can also use negative +// indices. +// TODO(tmorris): Use this method in more ops. +tensorflow::Status ConvertAxis(int tf_axis, int trt_nb_dims, + absl::string_view node_name, int* trt_axis) { + const int tf_nb_dims = trt_nb_dims + 1; + // Check bounds. + if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) { + return tensorflow::errors::InvalidArgument( + "Axis value of ", tf_axis, " is out of bounds, must be in range [", + -tf_nb_dims, ", ", tf_nb_dims, "), at ", node_name); + } + // Make negative axis positive. + if (tf_axis < 0) tf_axis += tf_nb_dims; + // Don't allow axis to be the batch dimension. + if (tf_axis == 0) { + return tensorflow::errors::Unimplemented( + "TensorRT does not allow manipulation of the batch dimension, at ", + node_name); + } + // Remove batch dimension. + *trt_axis = tf_axis - 1; + return Status::OK(); +} + inline bool DimsEqual(const nvinfer1::Dims& dim_l, const nvinfer1::Dims& dim_r) { if (dim_l.nbDims != dim_r.nbDims) { @@ -392,6 +417,15 @@ inline bool DimsEqual(const nvinfer1::Dims& dim_l, return true; } +bool AllLengthsEqual(const std::vector>& inputs) { + if (inputs.size() == 0) return true; + int length = inputs.at(0).size(); + for (int i = 1; i < inputs.size(); i++) { + if (inputs.at(i).size() != length) return false; + } + return true; +} + inline nvinfer1::Dims GetTrtDimsForTensor(const tensorflow::Tensor& tensor) { nvinfer1::Dims dims; dims.nbDims = tensor.dims(); @@ -529,6 +563,16 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { float getDynamicRange() const override { return 0; } #endif +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + bool dynamicRangeIsSet() const override { return true; } + + void resetDynamicRange() override {} + + float getDynamicRangeMin() const override { return 0.f; } + + float getDynamicRangeMax() const override { return 0.f; } +#endif + private: nvinfer1::DataType trt_dtype_; nvinfer1::Dims trt_dims_; @@ -2151,100 +2195,73 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) { return tensorflow::Status::OK(); } -// Gets the bounds (start or end) from the weights of a StridedSlice op. -tensorflow::Status GetStridedSliceBound(const std::vector& input_dims, - const TRT_ShapedWeights& bound_weights, - int mask, bool begin, string node_name, - std::vector* output_bound) { - const string bound_name = (begin) ? "begin" : "end"; - const int* weights_ptr = static_cast(bound_weights.GetValues()); - *output_bound = - std::vector(weights_ptr, weights_ptr + bound_weights.count()); - if (output_bound->size() != input_dims.size()) { - return tensorflow::errors::InvalidArgument( - "StridedSlice \"", bound_name, "\" specified ", - std::to_string(output_bound->size()), " dimensions, but input rank is ", - std::to_string(input_dims.size()), ", at ", node_name); - } - for (int i = 0; i < output_bound->size(); i++) { - if ((1 << i) & mask) { - // Apply mask. - (*output_bound)[i] = (begin) ? 0 : input_dims[i]; - // Masked bound will always result in a valid, non-negative bound, so we - // don't need the following checks. For the common case of using masks on - // a undefined batch dim (-1), we specifically don't want to do the - // following checks because they will erroneously detect an out of range - // bound or try to correct the negative value. - continue; - } - // Make sure bound is valid. - if (((*output_bound)[i] < -input_dims[i]) || - ((*output_bound)[i] > input_dims[i])) { +tensorflow::Status ConvertStridedSliceHelper(OpConverterParams* params, + const TRT_TensorOrWeights& input, + std::vector begin, + std::vector size, + const std::vector& stride) { + const auto& node_def = params->node_def; + // Get input dims. + nvinfer1::Dims dims = input.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Temporarily add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), -1); + // Check bounds. + for (int i = 1; i < input_dims.size(); i++) { + if (begin[i] < 0 || begin[i] > input_dims[i]) { return tensorflow::errors::InvalidArgument( - bound_name, " value of ", std::to_string((*output_bound)[i]), - " for StridedSlice is invalid, must be in the range " - "[-dim_size(i), dim_size(i)], at ", - node_name); + "\"begin\" for dimension ", std::to_string(i), " in ", node_def.op(), + " is out of range, at ", node_def.name()); } - // Convert negative values to their positive equivalent. - if ((*output_bound)[i] < 0) { - (*output_bound)[i] += input_dims[i]; + const int end = begin[i] + size[i]; + if (end < 0 || end > input_dims[i]) { + return tensorflow::errors::InvalidArgument( + "\"begin\" + \"size\" for dimension ", std::to_string(i), " in ", + node_def.op(), " is out of range, at ", node_def.name()); + } + if (size[i] <= 0) { + return tensorflow::errors::InvalidArgument( + "\"size\" cannot be negative or zero for ", node_def.op(), ", at ", + node_def.name()); } } - return tensorflow::Status::OK(); -} +// TRT 5.1 adds a slice layer. For older versions, we attempt to use the +// padding layer with negative padding. +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + // Use ISliceLayer. + nvinfer1::Dims begin_dims, size_dims, stride_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(begin, &begin_dims, + /*ignore_first_dim=*/true)); + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(size, &size_dims, + /*ignore_first_dim=*/true)); + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(stride, &stride_dims, + /*ignore_first_dim=*/true)); + if (params->validation_only) return Status::OK(); -tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { - const auto& inputs = params->inputs; - const auto& node_def = params->node_def; - TF_RETURN_IF_ERROR(CheckInputsWeights( - *params, - {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}})); - // Get input dims. - nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); - std::vector input_dims(dims.d, dims.d + dims.nbDims); - if (inputs.at(0).is_tensor()) { - // Temporarily add batch dimension so that indexes line up properly. - input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); - } - if (input_dims.size() > 4) { - return tensorflow::errors::Unimplemented( - "StridedSlice is not implemented for tensors with rank > 4, at ", - node_def.name()); - } - TFAttrs attrs(node_def); - // Get begin and end bounds per axis. - std::vector begin, end; - TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(1).weights(), - attrs.get("begin_mask"), true, - node_def.name(), &begin)); - TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(2).weights(), - attrs.get("end_mask"), false, - node_def.name(), &end)); - // Get strides per axis (must all be 1). - TRT_ShapedWeights stride_weights = inputs.at(3).weights(); - const int* stride_weights_ptr = static_cast(stride_weights.GetValues()); - std::vector strides(stride_weights_ptr, - stride_weights_ptr + stride_weights.count()); - for (int x : strides) { + nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice( + *const_cast(input.tensor()), begin_dims, size_dims, + stride_dims); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return tensorflow::Status::OK(); +#else + // Use IPaddingLayer. + // Strides must be 1 in this case. + for (int x : stride) { if (x != 1) { return tensorflow::errors::Unimplemented( - "StridedSlice is only implemented for stride of 1, at ", + "Strides other than 1 are not supported with this version of TRT, " + "at ", node_def.name()); } } - // Unsupported mask options. - for (const string& attr : - {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { - int attr_val = attrs.get(attr); - if (attr_val != 0) { - return tensorflow::errors::Unimplemented( - attr, " is not supported for StridedSlice, at ", node_def.name()); - } + // Rank must be 2, 3 or 4. + if (input_dims.size() > 4) { + return tensorflow::errors::Unimplemented(node_def.op(), + " for tensors with rank > 4 is " + "not supported in this version of " + "TRT, at ", + node_def.name()); } - - nvinfer1::ITensor* tensor = - const_cast(inputs.at(0).tensor()); // Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input. const bool need_reshape = (input_dims.size() != 4); int reshape_dims_added = 0; @@ -2254,7 +2271,7 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { while (input_dims.size() < 4) { input_dims.insert(input_dims.begin() + 1, 1); begin.insert(begin.begin() + 1, 0); - end.insert(end.begin() + 1, 1); + size.insert(size.begin() + 1, 1); reshape_dims_added++; } TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &reshape_dims, @@ -2262,23 +2279,22 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { } // Find dimensions which need to be sliced. std::vector pad_dims; - for (int i = 0; i < input_dims.size(); i++) { - if ((begin[i] != 0) || (end[i] != input_dims[i])) { - if (i == 0) { - return tensorflow::errors::Unimplemented( - "StridedSlice can't modify batch dim, at ", node_def.name()); - } else if ((end[i] - begin[i]) < 0) { - return tensorflow::errors::InvalidArgument( - "New size of sliced dimension is negative, at ", node_def.name()); - } + for (int i = 1; i < input_dims.size(); i++) { + if ((begin[i] != 0) || (begin[i] + size[i] != input_dims[i])) { pad_dims.push_back(i); } } if (pad_dims.empty()) { - // No dimensions are changed. We could create a padding layer anyway with - // values of 0. + // No dimensions are changed, so this is a no-op. We could just return the + // input without creating a new layer. TRT will crash if an empty engine + // with no layers is attempted to be created, so we add a no-op shuffle to + // prevent our unit tests from breaking. + // TODO(tmorris): Allow empty engines in the unit tests and return the input + // as output here. if (params->validation_only) return Status::OK(); - params->outputs->push_back(inputs.at(0)); + nvinfer1::IShuffleLayer* layer = params->converter->network()->addShuffle( + *const_cast(input.tensor())); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return tensorflow::Status::OK(); } else if (pad_dims.size() == 1) { // Only one dim is modified but we have to have 2, mark a second dim which @@ -2291,16 +2307,19 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { } } else if (pad_dims.size() > 2) { return tensorflow::errors::Unimplemented( - "StridedSlice can only modify 2 dimensions, at ", node_def.name()); + node_def.op(), + " can only modify up to 2 dimensions in this version of TRT, at ", + node_def.name()); } std::sort(pad_dims.begin(), pad_dims.end()); // Convert to pre/post padding values. Since TRT does not have a StridedSlice - // or Slice layer, we instead create an IPaddingLayer with negative padding. + // or Slice layer prior to 5.1, we instead create an IPaddingLayer with + // negative padding. nvinfer1::DimsHW pre_padding, post_padding; for (int i = 0; i < pad_dims.size(); i++) { const int axis = pad_dims[i]; pre_padding.d[i] = -begin[axis]; - post_padding.d[i] = end[axis] - input_dims[axis]; + post_padding.d[i] = (begin[axis] + size[axis]) - input_dims[axis]; } // IPaddingLayer will always apply the padding to dims 2,3 (input format is @@ -2320,10 +2339,11 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { if (params->validation_only) return Status::OK(); // Start conversion. + nvinfer1::ITensor* tensor = const_cast(input.tensor()); if (need_reshape) { const nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), reshape_dims, &output_tensor)); + input, reshape_dims, &output_tensor)); tensor = const_cast(output_tensor); } if (need_transpose) { @@ -2332,7 +2352,6 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { tensor, transpose_order, &output_tensor)); tensor = const_cast(output_tensor); } - // Add padding layer nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( *const_cast(tensor), pre_padding, post_padding); @@ -2340,7 +2359,6 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { params->converter->MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); tensor = layer->getOutput(0); - // Restore transpose if (need_transpose) { const nvinfer1::ITensor* output_tensor = nullptr; @@ -2353,7 +2371,7 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { // Calculate output dimensions for (int i = 0; i < pad_dims.size(); i++) { const int axis = pad_dims[i]; - input_dims[axis] = end[axis] - begin[axis]; + input_dims[axis] = size[axis]; } // Remove added 1 dimensions for (int i = 0; i < reshape_dims_added; i++) { @@ -2377,6 +2395,135 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { params->outputs->push_back( TRT_TensorOrWeights(const_cast(tensor))); return tensorflow::Status::OK(); +#endif +} + +tensorflow::Status ConvertSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"input", false}, {"begin", true}, {"size", true}})); + std::vector begin = inputs.at(1).weights().ToVector(); + std::vector size = inputs.at(2).weights().ToVector(); + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + if (!AllLengthsEqual({input_dims, begin, size})) { + return tensorflow::errors::InvalidArgument( + "Length of begin and size arguments must equal rank of input for " + "Slice, at ", + node_def.name()); + } + // Check that batch dimension is unmodified. + const bool begin_is_modified = begin[0] != 0; + // If size[0]s is not -1, we can only know if the batch dimension is + // unmodified when the batch size is defined. When the batch size is + // undefined, we don't convert to be safe. + const bool batch_size_is_defined = input_dims[0] > 0; + const bool size_is_modified = + size[0] != -1 && (!batch_size_is_defined || + (batch_size_is_defined && size[0] != input_dims[0])); + if (begin_is_modified || size_is_modified) { + return tensorflow::errors::Unimplemented( + "TensorRT does not allow modifications to the batch dimension, at ", + node_def.name()); + } + // Size of -1 signifies to take all remaining elements. + for (int i = 1; i < input_dims.size(); i++) { + if (size[i] == -1) { + size[i] = input_dims[i] - begin[i]; + } + } + // Stride is 1 for all dims. + std::vector stride(begin.size(), 1); + return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride); +} + +tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, + {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}})); + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + // Get begin and end bounds per axis. + std::vector begin = inputs.at(1).weights().ToVector(); + std::vector end = inputs.at(2).weights().ToVector(); + std::vector stride = inputs.at(3).weights().ToVector(); + if (!AllLengthsEqual({input_dims, begin, end, stride})) { + return tensorflow::errors::InvalidArgument( + "Length of begin, end, and stride arguments must equal rank of input " + "for StridedSlice, at ", + node_def.name()); + } + // Unsupported mask options. + TFAttrs attrs(node_def); + for (const string& attr : + {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { + int attr_val = attrs.get(attr); + if (attr_val != 0) { + return tensorflow::errors::Unimplemented( + attr, " is not supported for StridedSlice, at ", node_def.name()); + } + } + const int begin_mask = attrs.get("begin_mask"); + const int end_mask = attrs.get("end_mask"); + // Check that batch dimension is unmodified. + const bool begin_is_modified = !(begin_mask & 1) && begin[0] != 0; + const bool stride_is_modified = stride[0] != 1; + // If the batch size is -1 and the end mask is not set, we can only know if + // the batch dimension is unmodified when the batch size is defined. When the + // batch size is undefined, we don't convert to be safe. + const bool batch_size_is_defined = input_dims[0] > 0; + const bool end_is_modified = + !(end_mask & 1) && (!batch_size_is_defined || + (batch_size_is_defined && end[0] != input_dims[0])); + if (begin_is_modified || stride_is_modified || end_is_modified) { + return tensorflow::errors::Unimplemented( + "TensorRT does not allow modifications to the batch dimension, at ", + node_def.name()); + } + // Standarize begin and end bounds by applying masks, making negative values + // positive, and correcting out of bounds ranges (StridedSlice does this + // silently). + for (int i = 1; i < input_dims.size(); i++) { + // Begin + if ((1 << i) & begin_mask) { + begin[i] = 0; + } else if (begin[i] < 0) { + begin[i] += input_dims[i]; + } + begin[i] = std::max(0, std::min(begin[i], input_dims[i])); + // End + if ((1 << i) & end_mask) { + end[i] = input_dims[i]; + } else if (end[i] < 0) { + end[i] += input_dims[i]; + } + end[i] = std::max(0, std::min(end[i], input_dims[i])); + } + // Negative or zero strides currently not supported. + for (int i = 0; i < input_dims.size(); i++) { + if (stride[i] <= 0) { + return tensorflow::errors::Unimplemented( + "Negative or zero stride values are not supported for StridedSlice, " + "at ", + node_def.name()); + } + } + // TRT Slice layer uses (begin, size) instead of (begin, end) + std::vector size(input_dims.size()); + for (int i = 0; i < input_dims.size(); i++) { + // Divide by stride (round up) + size[i] = (end[i] - begin[i] + stride[i] - 1) / stride[i]; + } + return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride); } tensorflow::Status ConvertConv2D(OpConverterParams* params) { @@ -3413,6 +3560,29 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertGather(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"params", false}, {"indices", false}, {"axis", true}})); + absl::Span axis = inputs.at(2).weights().GetSpan(); + if (axis.size() != 1) { + return tensorflow::errors::InvalidArgument( + "Axis for GatherV2 must be a scalar, at ", node_def.name()); + } + int trt_axis = 0; + TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims, + node_def.name(), &trt_axis)); + if (params->validation_only) return Status::OK(); + + nvinfer1::IGatherLayer* layer = params->converter->network()->addGather( + *const_cast(inputs.at(0).tensor()), + *const_cast(inputs.at(1).tensor()), trt_axis); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); +} + tensorflow::Status ConvertMatMulHelper(OpConverterParams* params, TRT_TensorOrWeights tensor_input, TRT_ShapedWeights weights_raw, @@ -3643,11 +3813,13 @@ static void RegisterValidatableOpConverters( (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput; (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; (*registration)["ExpandDims"] = ConvertExpandDims; + (*registration)["GatherV2"] = ConvertGather; (*registration)["LeakyRelu"] = ConvertLeakyRelu; (*registration)["MatMul"] = ConvertMatMul; (*registration)["Pad"] = ConvertPad; (*registration)["Relu6"] = ConvertRelu6; (*registration)["Reshape"] = ConvertReshape; + (*registration)["Slice"] = ConvertSlice; (*registration)["Square"] = ConvertSquare; (*registration)["Squeeze"] = ConvertSqueeze; (*registration)["StridedSlice"] = ConvertStridedSlice; @@ -3721,8 +3893,12 @@ tensorflow::Status ConvertGraphDefToEngine( builder->setMaxWorkspaceSize(max_workspace_size_bytes); builder->setGpuAllocator(allocator); if (precision_mode == TrtPrecisionMode::FP16) { - builder->setHalf2Mode(true); + builder->setFp16Mode(true); } else if (precision_mode == TrtPrecisionMode::INT8) { + // Setting FP16 mode as well allows TRT to also consider FP16 kernels and + // use them in situations where they are faster than INT8 or where INT8 is + // not supported for a given layer. + builder->setFp16Mode(true); builder->setInt8Mode(true); if (use_calibration) { builder->setInt8Calibrator(calibrator); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index d1e30eb848bd6ab62719ca6da561d14b05d8537d..45edafd2be7e9be9f6752940b712e0d96d67550c 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -190,6 +190,17 @@ class TRT_ShapedWeights { string DebugString() const; + template + absl::Span GetSpan() const { + return absl::Span(tensor_.flat().data(), count()); + } + + template + std::vector ToVector() const { + auto span = GetSpan(); + return std::vector(span.data(), span.data() + span.size()); + } + // TODO(aaroey): make these private. nvinfer1::Dims shape_; // Note: shape.type[] is not used. tensorflow::DataType type_; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index bb1341ada3766ea322029ec2904e4ae2c6f5544d..f29d0b94d97a259ce4ced51eba1d0a3aa2b33536 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -2606,46 +2606,62 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { AddTestWeights("strides", {4}, {1, 1, 1, 1}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "StridedSlice can't modify batch dim, at my_strided_slice"); + "TensorRT does not allow modifications to the batch dimension, at " + "my_strided_slice"); } { - // Stride is not 1, should fail. + // Dynamic batch size without end_mask, should fail. Reset(); NodeDef node_def = get_strided_slice_nodedef(); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); AddTestWeights("begin", {4}, {0, 0, 0, 0}); AddTestWeights("end", {4}, {1, 1, 2, 3}); - AddTestWeights("strides", {4}, {1, 2, -1, 3}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "StridedSlice is only implemented for stride of " - "1, at my_strided_slice"); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_strided_slice"); } { - // Begin out of bounds, should fail. + // Dynamic batch size but using end_mask, ok. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0, + /*end_mask=*/1); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {0, 1, 2, 2}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion(node_def); + } +// TRT 5.1+ supports strides +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + { + // Negative strides, should fail. Reset(); NodeDef node_def = get_strided_slice_nodedef(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("begin", {4}, {1, 2, 3, 4}); - AddTestWeights("end", {4}, {0, 1, 2, 3}); - AddTestWeights("strides", {4}, {1, 1, 1, 1}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "begin value of 2 for StridedSlice is invalid, must be in the range " - "[-dim_size(i), dim_size(i)], at my_strided_slice"); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, -1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Negative or zero stride values are not " + "supported for StridedSlice, at " + "my_strided_slice"); } +#else { - // End out of bounds, should fail. + // Stride is not 1, should fail. Reset(); NodeDef node_def = get_strided_slice_nodedef(); AddTestTensor("input", {1, 2, 3}); AddTestWeights("begin", {4}, {0, 0, 0, 0}); - AddTestWeights("end", {4}, {1, 2, 3, 4}); - AddTestWeights("strides", {4}, {1, 1, 1, 1}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "end value of 2 for StridedSlice is invalid, must be in the range " - "[-dim_size(i), dim_size(i)], at my_strided_slice"); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 2, 1, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Strides other than 1 are not supported with " + "this version of TRT, at my_strided_slice"); } +#endif { // Size of sliced dim is negative, should fail. Reset(); @@ -2654,19 +2670,20 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { AddTestWeights("begin", {4}, {0, 0, 2, 0}); AddTestWeights("end", {4}, {1, 1, 0, 3}); AddTestWeights("strides", {4}, {1, 1, 1, 1}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "New size of sliced dimension is negative, at my_strided_slice"); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "\"size\" cannot be negative or zero for " + "StridedSlice, at my_strided_slice"); } struct TestParams { std::vector input_dims; - std::vector expected_output_dims; std::vector begin; std::vector end; + std::vector strides; int begin_mask; int end_mask; - std::vector expected_output; + std::vector expected_output_dims; + std::vector expected_output; }; auto get_mask = [](const std::vector& mask) { @@ -2677,105 +2694,159 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { return result; }; + // Same input is used for all tests. + const std::vector ok_input = {1, 2, 3, 4, 5, 6}; + +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + const int kStridedSliceOKCases = 23; +#else + const int kStridedSliceOKCases = 19; +#endif // Ok. - const int kStridedSliceOKCases = 18; TestParams ok_params[kStridedSliceOKCases] = { - // 2D Crop. - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 1, 2}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 0, 0}), - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 1, 1}), - /*expected_output=*/{5, 6}}, - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 0, 0}), - /*expected_output=*/{5, 6}}, - // 2D Crop, with transpose. - TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 0, 0}), - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, - /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 0, 0}), - /*expected_output=*/{5, 6}}, - TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 0, 0}), - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 0, 0}), - /*expected_output=*/{5, 6}}, - // 2D Crop, with reshape. - TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, - /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 0}), - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, - /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 1}), - /*expected_output=*/{5, 6}}, - // 1D Crop. - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 2, 2}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 1, 0}), - /*expected_output=*/{1, 2, 4, 5}}, - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 3}, - /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 1, 1}), - /*expected_output=*/{4, 5, 6}}, - // 1D Crop, with transpose. - TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 1, 1}), - /*expected_output=*/{1, 2, 3}}, - TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, - /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 1, 1}), - /*expected_output=*/{4, 5, 6}}, - // 1D Crop, with reshape. - TestParams{/*input_dims=*/{6}, /*expected_output_dims=*/{3}, - /*begin=*/{0, 0}, /*end=*/{0, 3}, - /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), - /*expected_output=*/{1, 2, 3}}, - TestParams{/*input_dims=*/{1, 6}, /*expected_output_dims=*/{1, 3}, - /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 0}), - /*expected_output=*/{3, 4, 5}}, - TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, - /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 1}), - /*expected_output=*/{3, 4, 5}}, - // Negative axis. - TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, - /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 1}), - /*expected_output=*/{1, 2, 3}}, - TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{5, 1}, - /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 1}), - /*expected_output=*/{1, 2, 3, 4, 5}}, + // 2D Crop. + TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, 0, 0}, + /*end=*/{0, 0, 1, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0, 0}), + /*expected_output_dims=*/{1, 1, 2}, /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with transpose. + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{5, 6}}, + TestParams{ + /*input_dims=*/{2, 1, 3}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{2, 1, 3}, + /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with reshape. + TestParams{/*input_dims=*/{2, 3}, + /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0}), + /*expected_output_dims=*/{1, 2}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3}, + /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1}), + /*expected_output_dims=*/{1, 2}, + /*expected_output=*/{5, 6}}, + // 1D Crop. + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 0}), /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 2, 4, 5}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 3}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with transpose. + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1, 1}), /*expected_output_dims=*/{1, 3, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 3, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with reshape. + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 3}, /*strides=*/{1, 1}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{1, 6}, + /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0}), + /*expected_output_dims=*/{1, 3}, + /*expected_output=*/{3, 4, 5}}, + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{3, 1}, + /*expected_output=*/{3, 4, 5}}, + // Negative axis. + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{3, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{5, 1}, + /*expected_output=*/{1, 2, 3, 4, 5}}, + // Clamp out of bounds begin and end. + TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, -9999, -9}, + /*end=*/{0, 1, 1000, 4}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 2, 3, 4, 5, 6}}, +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + // Strides + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 5}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 3, 5}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 6}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 3, 5}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 1}, /*end=*/{0, 6}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{2, 4, 6}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 2}, /*end=*/{0, 6}, /*strides=*/{1, 3}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{2}, + /*expected_output=*/{3, 6}}, +#endif }; for (int i = 0; i < kStridedSliceOKCases; i++) { @@ -2788,16 +2859,18 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { ok_params[i].begin); AddTestWeights("end", {static_cast(ok_params[i].end.size())}, ok_params[i].end); - std::vector strides(ok_params[i].input_dims.size(), 1); - AddTestWeights("strides", {static_cast(strides.size())}, - strides); + AddTestWeights("strides", + {static_cast(ok_params[i].strides.size())}, + ok_params[i].strides); RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); - const DataVec input_data{ - {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + const DataVec input_data{{"input", test::AsTensor(ok_input)}}; DataVec output_data{ {"my_strided_slice", ConstructTensor(ok_params[i].expected_output.size())}}; @@ -2807,6 +2880,148 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { } } +TEST_F(OpConverterTest, ConvertSlice) { + // Get nodedef for Slice layer. + auto get_slice_nodedef = []() -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32); + auto size = ops::Placeholder(s.WithOpName("size"), DT_INT32); + auto slice = ops::Slice(s.WithOpName("my_slice"), input, begin, size); + return slice.operation.node()->def(); + }; + + { + // Begin is below bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, -1, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" for dimension 2 in Slice is out of range, at my_slice"); + } + { + // Begin is above bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 3, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" for dimension 2 in Slice is out of range, at my_slice"); + } + { + // Size is below bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 2, -2}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" + \"size\" for dimension 3 in Slice is out of range, at " + "my_slice"); + } + { + // Size is above bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 3, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" + \"size\" for dimension 2 in Slice is out of range, at " + "my_slice"); + } + { + // Modify batch dim, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {0, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_slice"); + } + { + // Dynamic batch size with size[0] not -1, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_slice"); + } + { + // Dynamic batch size but using size[0] of -1, ok. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {-1, 1, 2, 2}); + RunValidationAndConversion(node_def); + } + + struct TestParams { + std::vector input_dims; + std::vector begin; + std::vector size; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Ok. + const int kSliceOKCases = 5; + TestParams ok_params[kSliceOKCases] = { + TestParams{{1, 2, 3}, + {0, 0, 0, 0}, + {-1, -1, -1, -1}, + {1, 2, 3}, + {1, 2, 3, 4, 5, 6}}, + TestParams{ + {1, 2, 3}, {0, 0, 0, 0}, {1, 1, 2, 3}, {1, 2, 3}, {1, 2, 3, 4, 5, 6}}, + TestParams{ + {1, 2, 3}, {0, 0, 0, 0}, {1, -1, 2, 2}, {1, 2, 2}, {1, 2, 4, 5}}, + TestParams{{6}, {0, 1}, {1, 5}, {5}, {2, 3, 4, 5, 6}}, + TestParams{{6}, {0, 1}, {-1, 3}, {3}, {2, 3, 4}}, + }; + + for (int i = 0; i < kSliceOKCases; i++) { + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("begin", + {static_cast(ok_params[i].begin.size())}, + ok_params[i].begin); + AddTestWeights("size", {static_cast(ok_params[i].size.size())}, + ok_params[i].size); + RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_slice", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_slice", ConstructTensor( + ok_params[i].expected_output.size())}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } +} + TEST_F(OpConverterTest, ConvertConv2D) { { // Input list is empty, should fail. @@ -3129,6 +3344,126 @@ TEST_F(OpConverterTest, ConvertTopK) { } } +template +void TestConvertGather(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + // Get the NodeDef for GatherV2. + Scope s = Scope::NewRootScope(); + auto params = ops::Placeholder(s.WithOpName("params"), dtype); + auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); + const NodeDef& node_def = gather.operation.node()->def(); + + struct TestParams { + std::vector params_dims; + std::vector indices_dims; + std::vector indices; + int axis; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Input is the same {1, 2, 3, 4, 5, 6} for all cases. + const int kGatherOKCases = 5; + TestParams ok_params[kGatherOKCases] = { + // Vector indices (output is rank(params)). + TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1}, {1, 4}}, + TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1}, {2, 5}}, + TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1}, {3, 6}}, + TestParams{{1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 3}, {3, 1, 2, 6, 4, 5}}, + // Higher rank indices (output is rank(params) + rank(indices) - 1). + TestParams{{1, 2, 3}, {1, 1}, {0}, 2, {1, 1, 1, 3}, {1, 2, 3}}, + }; + + // Ok. + for (int i = 0; i < kGatherOKCases; i++) { + test->Reset(); + test->AddTestTensor("params", ok_params[i].params_dims, 1, + TfDataTypeToTrt(dtype)); + test->AddTestTensor("indices", ok_params[i].indices_dims, 1, + nvinfer1::DataType::kINT32); + test->AddTestWeights("axis", {1}, {ok_params[i].axis}); + test->RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + // Create input in CType and convert expected output to CType. + std::vector inputs = {CType(1), CType(2), CType(3), + CType(4), CType(5), CType(6)}; + std::vector converted_expected_output( + ok_params[i].expected_output.begin(), + ok_params[i].expected_output.end()); + + const DataVec input_data{ + {"params", test::AsTensor(inputs)}, + {"indices", test::AsTensor(ok_params[i].indices)}}; + DataVec output_data{ + {"my_gather", + ConstructTensor(ok_params[i].expected_output.size())}}; + test->BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(converted_expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertGather) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_gather", "GatherV2", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "GatherV2 got 0 inputs but expected 3, at my_gather"); + } + + // Get the NodeDef for GatherV2. + Scope s = Scope::NewRootScope(); + auto params = ops::Placeholder(s.WithOpName("params"), DT_FLOAT); + auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); + const NodeDef& node_def = gather.operation.node()->def(); + { + // Axis is a tensor, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestTensor("axis", {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"axis\" for GatherV2 must be a constant, at my_gather"); + } + { + // Axis is out of bounds, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {4}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Axis value of 4 is out of bounds, must be in " + "range [-4, 4), at my_gather"); + } + { + // Axis is batch dimension, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {0}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_gather"); + } + + Reset(); + TestConvertGather(this); + TestConvertGather(this); + TestConvertGather(this); +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index e3b31d736eb89b079410bba34b26d259ef2c2527..f6d387c59cd04aa5c7ccad610290b7b1f1d2b11f 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -295,27 +294,6 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, return this->AllocateCalibrationResources(ctx, cr); }})); tensorflow::core::ScopedUnref calib_sc(calib_res); - // TODO(aaroey): here we also add the resource to the ResourceMgr singleton. - // This is needed before we migrate all uses of calib_graph_to_infer_graph() - // to the new calibration workflow. After that we'll remove this block. - { - auto deprecated_rm = - TRTResourceManager::instance()->getManager("TRTCalibration"); - TRTCalibrationResource* copied_resource = nullptr; - // Check whether the resource exists, and create it if not. - if (deprecated_rm->Lookup(funcdef_name_, "Calibrator", &copied_resource) - .ok()) { - // Do nothing if the resource exists. - copied_resource->Unref(); - } else { - copied_resource = calib_res; - // Increase the refcount by 1 then transfer the ownership of that refcount - // to the ResourceMgr singleton. - copied_resource->Ref(); - OP_REQUIRES_OK(ctx, deprecated_rm->Create(funcdef_name_, "Calibrator", - copied_resource)); - } - } int num_inputs = ctx->num_inputs(); // Pass input data to calibrator std::unordered_map input_data; diff --git a/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc index 3bcca99afbff8b84d2dd628ae9211ee94e86af2a..dd3c09d7e42358a1f9e6cc13be6198de58e38963 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc @@ -19,7 +19,9 @@ limitations under the License. #include #include "re2/re2.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2tensorrt/utils/test_utils.h b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h index bcd628b62f0320f7ce9dfe6240316d876f1d5a20..d85875991b79014c4f173d3157ed02e6c96f045c 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/test_utils.h +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h @@ -16,8 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ #define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc deleted file mode 100644 index 0a72a88bc740101bcbadb40bfe106a5b8d284bbf..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" -#include "tensorflow/core/platform/logging.h" - -namespace tensorflow { -namespace tensorrt { - -std::shared_ptr -tensorflow::tensorrt::TRTResourceManager::instance() { - static std::shared_ptr instance_(new TRTResourceManager); - return instance_; -} - -std::shared_ptr -tensorflow::tensorrt::TRTResourceManager::getManager(const string& op_name) { - // mutex is held for lookup only. Most instantiations where mutex will be held - // longer will be during op creation and should be ok. - tensorflow::mutex_lock lock(map_mutex_); - auto s = managers_.find(op_name); - if (s == managers_.end()) { - auto it = managers_.emplace( - op_name, std::make_shared(op_name)); - VLOG(1) << "Returning a new manager " << op_name; - return it.first->second; - } - VLOG(1) << "Returning old manager " << op_name; - return s->second; -} - -} // namespace tensorrt -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h deleted file mode 100644 index 03879ffff2fa724b05cb1919753e4aaa99e2e702..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ -#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ -#include - -#include -#include -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { -namespace tensorrt { - -class TRTResourceManager { - TRTResourceManager() = default; - - public: - static std::shared_ptr instance(); - // returns a manager for given op, if it doesn't exists it creates one - std::shared_ptr getManager(const string& op_name); - - private: - std::unordered_map> - managers_; - tensorflow::mutex map_mutex_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5a1a9435c19160cfb8130253a9fa756af423165c..7d9e7b9fc1f7ea83d6aa982afb5df097b0bdbf77 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -24,7 +24,7 @@ package( ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") -load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library") cc_library( name = "tf2xla_supported_ops_lib", @@ -60,6 +60,14 @@ xla_proto_library( ], ) +xla_py_proto_library( + name = "tf2xla_py", + has_services = False, + api_version = 2, + visibility = ["//visibility:public"], + deps = [":tf2xla_proto"], +) + xla_proto_library( name = "host_compute_metadata_proto", srcs = ["host_compute_metadata.proto"], @@ -283,6 +291,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index b3f050c52b3a71067a1cc7aa0cd18905e35e4f1c..343568b2392595a2347bde41f0a2e2559fb1de19 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -107,11 +107,13 @@ tf_kernel_library( "xla_pad_op.cc", "xla_reduce_op.cc", "xla_select_and_scatter_op.cc", + "xla_self_adjoint_eig_op.cc", ], hdrs = [ "index_ops.h", "shape_util.h", ], + tags = ["optonly"], deps = [ ":conv_op_helpers", ":if_op", @@ -143,6 +145,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/client/lib:quantize", + "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/core:bitwise_ops_op_lib", "//tensorflow/core:control_flow_ops_op_lib", diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index b96d45316f626e678a64392a4315979eeeb6e83c..d19d48e5dd95962fe4a4e4026eaf6b06b7898564 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -134,14 +135,15 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, // If the 2D kernel would be very large, the 1D kernel can be applied once in // each dimension due to the symmetry of the kernel along all axis to reduce the // computational intensity. -xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, int64 n) { +xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, int64 n) { std::vector kernel(n * 2 - 1); for (int64 i = 0; i < n; ++i) { float v = (i + 1.0f) / n; kernel[i] = v; kernel[n * 2 - 2 - i] = v; } - return xla::ConstantR1(builder, kernel); + return xla::ConvertElementType(xla::ConstantR1(builder, kernel), type); } // Unlike the bilinear kernel, which is triangular, the nearest neighbor @@ -153,11 +155,12 @@ xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, int64 n) { // to the right (because an existing non TPU kernel // for nearest neighbor resize already chose to default to the right, // so we want to be consistent). -xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, int64 n) { +xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, int64 n) { std::vector kernel(n * 2 - 1, 0.0f); std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f); - return xla::ConstantR1(builder, kernel); + return xla::ConvertElementType(xla::ConstantR1(builder, kernel), type); } // Kernels with more than 16 spatial elements are considered intense and the @@ -165,42 +168,66 @@ xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, int64 n) { const int64 kMax2DKernelSize = 16; xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, absl::Span kernel_size, int64 channels, bool is_kernel_bilinear) { auto make_kernel_func = is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; - auto depthwise_kernel = xla::Broadcast( - xla::Zero(builder, xla::F32), - {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}); + std::vector depthwise_kernel_sizes = { + (2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}; + auto depthwise_kernel = + xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[1]), + depthwise_kernel_sizes, /*broadcast_dimensions=*/{1}); - return xla::Mul( - xla::Add(depthwise_kernel, make_kernel_func(builder, kernel_size[1]), - /*broadcast_dimensions=*/{1}), - make_kernel_func(builder, kernel_size[0]), - /*broadcast_dimensions=*/{0}); + return xla::Mul(depthwise_kernel, + make_kernel_func(builder, type, kernel_size[0]), + /*broadcast_dimensions=*/{0}); } xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder, + xla::PrimitiveType type, absl::Span kernel_size, int64 channels, int64 dim, bool is_kernel_bilinear) { auto make_kernel_func = is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; - auto depthwise_kernel = - xla::Broadcast(xla::Zero(builder, xla::F32), - {dim == 0 ? (2 * kernel_size[0] - 1) : 1, - dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1}); - return xla::Add(depthwise_kernel, make_kernel_func(builder, kernel_size[dim]), - /*broadcast_dimensions=*/{dim}); + std::vector depthwise_kernel_sizes = { + dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1}; + return xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[dim]), + depthwise_kernel_sizes, + /*broadcast_dimensions=*/{dim}); +} + +xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder, + const xla::XlaOp& input, + int32 spatial_dimensions_offset, + absl::Span in_size, + absl::Span out_size) { + // Add broadcasts to handle expanding from a size == 1 dimension to a + // size > 1 dimension. + auto broadcast_shape_or_status = builder->GetShape(input); + if (!broadcast_shape_or_status.ok()) { + return builder->ReportError(broadcast_shape_or_status.status()); + } + xla::Shape broadcast_shape = broadcast_shape_or_status.ValueOrDie(); + for (int32 i = 0; i < in_size.size(); ++i) { + if (in_size[i] == 1 && out_size[i] > 1) { + broadcast_shape.set_dimensions(spatial_dimensions_offset + i, + out_size[i]); + } + } + return xla::BroadcastInDim(input, broadcast_shape.dimensions(), + /*broadcast_dimensions=*/{0, 1, 2, 3}); } xla::XlaOp ResizeUsingDilationAndConvolution( - xla::XlaBuilder* builder, const xla::XlaOp& input, - const int num_spatial_dims, std::vector in_size, - std::vector out_size, const int64 channels, const bool align_corners, - bool is_kernel_bilinear) { + xla::XlaBuilder* builder, const xla::XlaOp& input, xla::PrimitiveType type, + const int num_spatial_dims, absl::Span in_size, + absl::Span out_size, const int64 channels, + const bool align_corners, bool is_kernel_bilinear) { // Picture for a 1x3 to 1x4 bilinear resize: // stride = 2, kernel size = 3 // Input: @@ -287,7 +314,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution( // Split convolutions into independent dimensions if they would be a very // large kernel. if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = MakeGeneralResizeKernel(builder, dims.kernel_size, + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size, channels, is_kernel_bilinear); output = xla::ConvGeneralDilated(input_data, kernel, dims.stride, @@ -299,7 +326,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution( /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 0, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear); output = xla::ConvGeneralDilated( input_data, kernel0, {dims.stride[0], 1}, /*padding=*/ @@ -308,7 +335,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution( /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 1, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear); output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ @@ -320,19 +347,14 @@ xla::XlaOp ResizeUsingDilationAndConvolution( // Add broadcasts to handle expanding from a size == 1 dimension to a // size > 1 dimension. - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && out_size[i] > 1) { - output = xla::Add(output, xla::ConstantR1(builder, out_size[i], 0), - /*broadcast_dimensions=*/{1 + i}); - } - } - return output; + return BroadcastSpatialDimensions( + builder, output, /*spatial_dimensions_offset=*/1, in_size, out_size); } xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( - xla::XlaBuilder* builder, const xla::XlaOp& grad, - const int num_spatial_dims, std::vector in_size, - std::vector grad_size, const int64 channels, + xla::XlaBuilder* builder, const xla::XlaOp& grad, xla::PrimitiveType type, + const int num_spatial_dims, absl::Span in_size, + absl::Span grad_size, const int64 channels, const bool align_corners, bool is_kernel_bilinear) { ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, grad_size, align_corners); @@ -353,19 +375,14 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); xla::XlaOp output; if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = MakeGeneralResizeKernel(builder, dims.kernel_size, + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size, channels, is_kernel_bilinear); // Broadcast the input kernel where the forward op expanded from a size == 1 // dimension to a size > 1 dimension. This has the effect of summing the // gradient contributions in that dimension. - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && grad_size[i] > 1) { - kernel = - xla::Add(kernel, xla::ConstantR1(builder, grad_size[i], 0), - /*broadcast_dimensions=*/{i}); - } - } + kernel = BroadcastSpatialDimensions( + builder, kernel, /*spatial_dimensions_offset=*/0, in_size, grad_size); output = xla::ConvGeneralDilated( grad, kernel, /*window_strides=*/dims.kernel_size, @@ -377,22 +394,22 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 0, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear); xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 1, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear); // Broadcast the input kernel where the forward op expanded from a // size == 1 dimension to a size > 1 dimension. This has the effect of // summing the gradient contributions in that dimension. if (in_size[0] == 1 && grad_size[0] > 1) { - kernel0 = - xla::Add(kernel0, xla::ConstantR1(builder, grad_size[0], 0), - /*broadcast_dimensions=*/{0}); + kernel0 = BroadcastSpatialDimensions(builder, kernel0, + /*spatial_dimensions_offset=*/0, {1}, + {grad_size[0]}); } if (in_size[1] == 1 && grad_size[1] > 1) { - kernel1 = - xla::Add(kernel0, xla::ConstantR1(builder, grad_size[1], 0), - /*broadcast_dimensions=*/{1}); + kernel1 = BroadcastSpatialDimensions(builder, kernel0, + /*spatial_dimensions_offset=*/0, + in_size, grad_size); } output = xla::ConvGeneralDilated( @@ -423,7 +440,7 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( } } if (pad_output) { - output = xla::Pad(output, xla::ConstantR0(builder, 0.0f), padding); + output = xla::Pad(output, xla::Zero(builder, type), padding); } return output; } @@ -458,6 +475,7 @@ void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, const int num_spatial_dims = 2; xla::XlaOp input = ctx->Input(0); + xla::PrimitiveType input_type = ctx->input_xla_type(0); // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in // dimension i. @@ -475,8 +493,11 @@ void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); } - // Output is always type float. - input = xla::ConvertElementType(input, xla::F32); + // Output is always type float if 'is_kernel_bilinear' is true. + if (is_kernel_bilinear) { + input = xla::ConvertElementType(input, xla::F32); + input_type = xla::F32; + } // Special Case: // Instead of doing a ResizeUsingDilationAndConvolution directly, @@ -504,19 +525,19 @@ void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, next_out_size, channels, - align_corners_, is_kernel_bilinear); + b, input, input_type, num_spatial_dims, in_size, next_out_size, + channels, align_corners_, is_kernel_bilinear); input = output; in_size = next_out_size; } else { output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, out_size, channels, + b, input, input_type, num_spatial_dims, in_size, out_size, channels, align_corners_, is_kernel_bilinear); in_size = out_size; } } else { output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, out_size, channels, + b, input, input_type, num_spatial_dims, in_size, out_size, channels, align_corners_, is_kernel_bilinear); in_size = out_size; } @@ -631,19 +652,19 @@ class ResizeBilinearGradOp : public XlaOpKernel { std::vector next_grad_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, next_grad_size, channels, - align_corners_, true); + b, grad, xla::F32, num_spatial_dims, in_size, next_grad_size, + channels, align_corners_, true); grad = output; in_size = next_grad_size; } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels, + b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels, align_corners_, true); in_size = grad_size; } } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels, + b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels, align_corners_, true); in_size = grad_size; } diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index a95e7adacf194ba6eb33cbeb56abe1a5a2479337..a1c18bed3f94008af8038f32324c79aa5b2abded 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -110,10 +110,16 @@ class ScatterNdOp : public XlaOpKernel { auto updates = context->Input(1); auto result = XlaScatter(buffer, updates, indices, - /*indices_are_vectors=*/true, /*combiner=*/{}, builder); + /*indices_are_vectors=*/true, /*combiner=*/Combine, builder); OP_REQUIRES_OK(context, result.status()); context->SetOutput(0, result.ValueOrDie()); } + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Add(x, y); + } }; REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"), diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 65020012283d9c5f62e5e2fd11fc2bf1110e019a..8958a48bc79dce91c41ab7d0a5fc0fbb401112ba 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" @@ -35,6 +36,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -69,6 +71,43 @@ class TensorListLengthOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp); +// Creates an empty list with size (leading_dim, *element_shape) if +// element_shape is known at compile time. Otherwise creates one with size +// (leading_dim, 0) which gets initialized later in `GetInitializedList`. +Status CreateZerosList(XlaOpKernelContext* ctx, int element_shape_index, + int64 leading_dim, DataType dtype, xla::XlaOp* list) { + TensorShape list_shape; + list_shape.AddDim(leading_dim); + xla::XlaOp element_shape_handle = ctx->Input(element_shape_index); + TF_ASSIGN_OR_RETURN( + bool is_element_shape_compile_time_const, + element_shape_handle.builder()->IsConstant(element_shape_handle)); + PartialTensorShape partial_element_shape; + if (is_element_shape_compile_time_const) { + TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape( + element_shape_index, &partial_element_shape)); + } + if (is_element_shape_compile_time_const && + partial_element_shape.IsFullyDefined()) { + TensorShape element_shape; + partial_element_shape.AsTensorShape(&element_shape); + list_shape.AppendShape(element_shape); + } else { + // If element_shape is not a compile time constant or if it is not fully + // defined we will have to wait for the first write call to fully allocate + // the array. + // TODO(srbs): We are using element_shape of [0] as a proxy to denote an + // uninitialized list. A better implementation may be to represent the + // list as a 3-tuple containining an explicit "initialized" flag. However, + // we would still need to create a dummy tensor for the first tuple + // element. + list_shape.AddDim(0); + } + *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype), + list_shape.dim_sizes()); + return Status::OK(); +} + class TensorListReserveOp : public XlaOpKernel { public: explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -76,20 +115,15 @@ class TensorListReserveOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - TensorShape element_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); int64 num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); - TensorShape tensor_shape; - tensor_shape.AddDim(num_elements); - tensor_shape.AppendShape(element_shape); + xla::XlaOp list; + OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &list)); xla::XlaBuilder* b = ctx->builder(); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), - tensor_shape.dim_sizes()), - xla::ConstantR0(b, num_elements)})); + 0, xla::Tuple(b, {list, xla::ConstantR0(b, num_elements)})); } private: @@ -110,8 +144,6 @@ class EmptyTensorListOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - TensorShape element_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); int64 max_num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements)); OP_REQUIRES( @@ -119,15 +151,13 @@ class EmptyTensorListOp : public XlaOpKernel { errors::InvalidArgument("XLA compilation requires a fixed tensor list " "size. Set the max number of elements.")); - TensorShape tensor_shape; - tensor_shape.AddDim(max_num_elements); - tensor_shape.AppendShape(element_shape); + xla::XlaOp list; + OP_REQUIRES_OK(ctx, + CreateZerosList(ctx, 0, max_num_elements, dtype_, &list)); xla::XlaBuilder* b = ctx->builder(); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), - tensor_shape.dim_sizes()), - xla::ConstantR0(b, 0)})); + 0, xla::Tuple(b, {list, xla::ConstantR0(b, 0)})); } private: @@ -274,6 +304,36 @@ REGISTER_XLA_OP( Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"), TensorListFromTensorOp); +// Returns the 0'th element of `tuple` containing the list tensor if it has been +// initialized already else creates one lazily. This allows lazy initialization +// of the list on the first call to SetItem or PushBack. +Status GetInitializedList(XlaOpKernelContext* ctx, const xla::XlaOp& tuple, + const TensorShape& element_shape, DataType dtype, + xla::XlaOp* list) { + *list = xla::GetTupleElement(tuple, 0); + TensorShape list_shape; + TF_RETURN_IF_ERROR(GetTensorListShape(ctx->builder(), tuple, &list_shape)); + int64 leading_dim = list_shape.dim_size(0); + TensorShape list_element_shape = list_shape; + list_element_shape.RemoveDim(0); + // This checks for the lazy initialization contract set by CreateEmptyList. + // In TensorListReserve if the element_shape is not known at compile time, + // it creates a list with shape [leading_dim, 0]. + if (element_shape != list_element_shape) { + if (list_element_shape.num_elements() != 0) { + return errors::InvalidArgument( + "Invalid shape of value in TensorListSetItem. Expected: ", + list_element_shape.DebugString(), + " Actual: ", element_shape.DebugString()); + } + list_shape = element_shape; + list_shape.InsertDim(0, leading_dim); + *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype), + list_shape.dim_sizes()); + } + return Status::OK(); +} + class TensorListSetItemOp : public XlaOpKernel { public: explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -285,7 +345,9 @@ class TensorListSetItemOp : public XlaOpKernel { xla::XlaOp tl = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(2); - xla::XlaOp ta = xla::GetTupleElement(tl, 0); + xla::XlaOp list; + OP_REQUIRES_OK(ctx, GetInitializedList(ctx, tl, elem_shape, dtype_, &list)); + xla::XlaOp index = ctx->Input(1); xla::XlaOp value = ctx->Input(2); @@ -299,8 +361,8 @@ class TensorListSetItemOp : public XlaOpKernel { auto update = xla::Reshape(value, slice_shape.dim_sizes()); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), - index + xla::ConstantR0(b, 1)})); + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), + xla::GetTupleElement(tl, 1)})); } private: @@ -319,11 +381,14 @@ class TensorListPushBackOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp tl = ctx->Input(0); + xla::XlaOp list_tuple = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(1); - xla::XlaOp ta = xla::GetTupleElement(tl, 0); - xla::XlaOp index = xla::GetTupleElement(tl, 1); + xla::XlaOp list; + OP_REQUIRES_OK( + ctx, GetInitializedList(ctx, list_tuple, elem_shape, dtype_, &list)); + + xla::XlaOp index = xla::GetTupleElement(list_tuple, 1); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. @@ -336,7 +401,7 @@ class TensorListPushBackOp : public XlaOpKernel { auto update = xla::Reshape(value, slice_shape.dim_sizes()); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), index + xla::ConstantR0(b, 1)})); } diff --git a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..233ac8e7b455403f8ee65b95b1403ecefdb92c6b --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/core/lib/core/bits.h" + +namespace tensorflow { +namespace { + +class XlaSelfAdjointEigOp : public XlaOpKernel { + public: + explicit XlaSelfAdjointEigOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); + } + void Compile(XlaOpKernelContext* ctx) override { + auto result = + xla::SelfAdjointEig(ctx->Input(0), lower_, max_iter_, epsilon_); + ctx->SetOutput(0, result.w); + ctx->SetOutput(1, result.v); + } + + private: + bool lower_; + int32 max_iter_; + float epsilon_; +}; + +class SelfAdjointEigV2Op : public XlaOpKernel { + public: + explicit SelfAdjointEigV2Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape("input"); + int n = input_shape.dim_size(input_shape.dims() - 1); + // This is based on heuristics that approx log(n) sweep updates are needed. + // Note: the heuristics provides no theoretical guarantee, max_iter=100 and + // epsilon should be used to determine exit condition. + int max_iter = 2 * tensorflow::Log2Ceiling(n); + auto result = xla::SelfAdjointEig(ctx->Input(0), true, max_iter, 1e-6); + ctx->SetOutput(0, result.w); + ctx->SetOutput(1, result.v); + } +}; + +REGISTER_XLA_OP(Name("XlaSelfAdjointEig").TypeConstraint("T", kFloatTypes), + XlaSelfAdjointEigOp); +REGISTER_XLA_OP(Name("SelfAdjointEigV2").TypeConstraint("T", kFloatTypes), + SelfAdjointEigV2Op); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index af641131ed76a8d6a7291c360302fa17c94af014..ccd58071d350e605e0e1f0c2b43643a400e32c2c 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -56,6 +56,41 @@ lhs_output: the broadcasted LHS tensor rhs_output: the broadcasted RHS tensor )doc"); +REGISTER_OP("XlaSelfAdjointEig") + .Input("a: T") + .Attr("lower: bool") + .Attr("max_iter: int") + .Attr("epsilon: float") + .Output("w: T") + .Output("v: T") + .SetShapeFn(shape_inference::UnknownShape) + .Attr("T: numbertype") + .Doc(R"doc( +Computes the eigen decomposition of a batch of self-adjoint matrices +(Note: Only real inputs are supported). + +Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in +tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for +i=0...N-1. + +a: the input tensor. + +lower: a boolean specifies whether the calculation is done with the lower + triangular part or the upper triangular part. + +max_iter: maximum number of sweep update, i.e., the whole lower triangular + part or upper triangular part based on parameter lower. Heuristically, it has + been argued that approximatly logN sweeps are needed in practice (Ref: Golub & + van Loan "Matrix Computation"). + +epsilon: the tolerance ratio. + +w: The eigenvalues in ascending order, each repeated according to its + multiplicity. +v: The column v[..., :, i] is the normalized eigenvector corresponding to the + eigenvalue w[..., i]. +)doc"); + REGISTER_OP("XlaConv") .Input("lhs: T") .Input("rhs: T") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 345193c936a885e5a9e468979c4b73b5b0c9e5c2..de4710d03a3e69afb04aa68e37961698f0e3a300 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -291,6 +291,10 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): name=name) +def self_adjoint_eig(a, lower, max_iter, epsilon): + return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon) + + dynamic_slice = gen_xla_ops.xla_dynamic_slice dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index ddb284966eeb97cc7c9d3ed77fb313e567975e59..5bd0277c051711f2677b90a2679662899521e94a 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -60,8 +60,6 @@ class XlaCompilationAllocator : public Allocator { // buffers, so they get ids to track. bool ShouldAllocateEmptyTensors() override { return true; } - void GetStats(AllocatorStats* stats) override { stats->Clear(); } - private: // Don't run any constructors or destructors for complex objects, // since there is no backing store for the tensor to run them diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 0833264523770dc43c6a784f8b3d731485f38e53..3221ec5b727de1f792cd61b792ee917588d56cf9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -185,9 +185,10 @@ Status BuildComputation( std::vector elems; elems.reserve(retvals.size()); - // Keeps track of which retvals have layout to update. The first element is - // the output index, second element is the new layout. - std::vector> retval_to_update_layout; + // Keeps track of the layout of each retval. If a retval is not in this list, + // a descending layout is used. The first element is the output index, second + // element is the new layout. + std::vector> retval_index_and_layout; for (int i = 0; i < retvals.size(); ++i) { XlaCompiler::OutputDescription& output = (*outputs)[i]; const XlaExpression& retval = retvals[i]; @@ -216,7 +217,7 @@ Status BuildComputation( TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( output.shape, output.type)); value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); - retval_to_update_layout.emplace_back(elems.size(), shape.layout()); + retval_index_and_layout.emplace_back(elems.size(), shape.layout()); } else if (it != retval_cores.end()) { // Apply the sharding to the output, if there is a core assignment. value = identity_op(value); @@ -289,6 +290,11 @@ Status BuildComputation( // Ensures the correct sharding is applied to the output. handle = identity_op(handle); + // Set layout of the retval to device representation layout. + if (resource->representation_shape().has_value()) { + retval_index_and_layout.emplace_back( + elems.size(), resource->representation_shape()->layout()); + } elems.push_back(handle); } } @@ -318,15 +324,15 @@ Status BuildComputation( computation->GetProgramShape()); *output_shape = program_shape.result(); // Update the output layout to the layout of retval. - for (auto& update : retval_to_update_layout) { + for (auto& index_and_layout : retval_index_and_layout) { if (!always_return_tuple && elems.size() == 1) { - *output_shape->mutable_layout() = update.second; + *output_shape->mutable_layout() = index_and_layout.second; continue; } - xla::Shape* output_sub_shape = - xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first}); - *output_sub_shape->mutable_layout() = update.second; + xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape( + output_shape, {index_and_layout.first}); + *output_sub_shape->mutable_layout() = index_and_layout.second; } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 492010f7317d32a8a620147cd2cd9356d4f13fde..b31137867d738944eaaa73e142ad8538ec6b854a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -277,6 +277,97 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); } +// Tests that the compiler can correctly propagate the layout assigned by +// shape_representation_fn_ to return types. +TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + options.shape_representation_fn = + [](const TensorShape& shape, DataType dt) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + // Compiles the graph. + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + xla::Shape transposed = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}); + // Check that the return shapes are correctly tranposed. + EXPECT_EQ(result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); +} + +// The layout of resource variable shouldn't change after transpose +TEST_F(XlaCompilerTest, TransposeVariables) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto transposed_read = ops::Transpose(scope, read, {1, 0}); + auto reshape = ops::Reshape(scope, transposed_read, {2, 3}); + auto d = ops::_Retval(scope.WithOpName("D"), reshape, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 3}); + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose", + std::move(graph), args, &result)); + xla::Shape transposed = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0}); + // Check that the return shapes are correctly tranposed. + EXPECT_EQ(result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); +} + // Tests that the compiler doesn't reorder the parameters. TEST_F(XlaCompilerTest, MixedOrderArguments) { for (bool swap_order : {false, true}) { diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 04a5d934064a9083a41cc210b48df65bbc862fff..7bb1ad27467a5b281626de4203169e575288f9ee 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -81,61 +81,27 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, return Status::OK(); } -template -static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { - Tensor linspace(DataTypeToEnum::v(), shape); - auto linspace_flat = linspace.flat(); - for (int64 i = 0; i < depth; ++i) { - linspace_flat(i) = i; - } - return linspace; -} - Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, const xla::XlaOp& indices, const xla::XlaOp& on_value, const xla::XlaOp& off_value, xla::XlaOp* one_hot) { - const int indices_dims = indices_shape.dims(); - const int output_dims = indices_dims + 1; - - TensorShape output_shape = indices_shape; - output_shape.InsertDim(axis, depth); - - // Build a Tensor populated with values 0, 1, 2, ... depth. - std::vector linspace_dims(output_dims, 1); - linspace_dims[axis] = depth; - TensorShape linspace_shape(linspace_dims); - Tensor linspace; - switch (index_type) { - case DT_UINT8: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - case DT_INT32: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - case DT_INT64: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - default: - return errors::InvalidArgument("Invalid argument type ", - DataTypeString(index_type)); - } - - xla::BorrowingLiteral linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); - // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. std::vector broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::XlaOp one_hot_bool = xla::Eq( - indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims); + + TensorShape output_shape = indices_shape; + output_shape.InsertDim(axis, depth); + xla::Shape iota_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(index_type, output_shape, &iota_shape)); // Selects the user-provided off_value and on_value values. - *one_hot = xla::Select(one_hot_bool, - xla::Broadcast(on_value, output_shape.dim_sizes()), - xla::Broadcast(off_value, output_shape.dim_sizes())); + *one_hot = xla::Select( + xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims), + xla::Broadcast(on_value, output_shape.dim_sizes()), + xla::Broadcast(off_value, output_shape.dim_sizes())); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index e36128831b4df3a749c8804fa25f7776e83e11c0..ee11f3a3de658c7e5108605122b84fbc3e1cd963 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -319,6 +319,27 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { return Status::OK(); } +Status XlaOpKernelContext::ConstantInputAsPartialShape( + int index, PartialTensorShape* shape) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); + // If `literal` is a scalar it's value must be -1. + if (literal.shape().rank() == 0) { + int64 shape_val; + TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val)); + if (shape_val != -1) { + return errors::InvalidArgument( + "Cannot convert value to PartialTensorShape: ", shape_val); + } + *shape = PartialTensorShape(); // Shape with unknown rank. + return Status::OK(); + } + std::vector dims; + TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); + *shape = PartialTensorShape(dims); + return Status::OK(); +} + Status XlaOpKernelContext::InputList(absl::string_view name, std::vector* handles, std::vector* shapes) { @@ -513,6 +534,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, handle = xla::Reshape(handle, xla::AsInt64Slice(representation_shape.dimensions())); } + variable->SetRepresentationShape(representation_shape); return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index b3cef40db86c1d073e3236f71f29e4002dcaa0d8..cc2d5e8de3eb020ba41dfed7d730b48cd0534b4c 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -138,6 +138,10 @@ class XlaOpKernelContext { // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); + // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1 + // into a PartialTensorShape. + Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape); + // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 736588bb8b89ba756cdce77eeebff8d1fcf4774c..ab3a5bdd9bc580c16d65d35c3be3ba8204511f83 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -86,6 +86,12 @@ class XlaResource { // variables have new values that need to be written back. const xla::XlaOp& initial_value() const { return initial_value_; } + // An xla shape that indicates how this resource variable is represented on + // device. + const absl::optional& representation_shape() const { + return representation_shape_; + } + // A variable is initialized if it has a value. bool initialized() const { return value_.valid(); } @@ -100,6 +106,11 @@ class XlaResource { // Sets the current value of the resource to an all-zero value. Status SetZeroValue(xla::XlaBuilder* builder); + // Sets the representational shape of the resource on device. + void SetRepresentationShape(const xla::Shape& shape) { + representation_shape_ = absl::make_optional(shape); + } + // Looks up the gradient for `source`, or creates it if it does not already // exist. The call target must be an initialized TensorArray resource. A // TensorArray can have multiple named gradients; see the operator @@ -160,6 +171,10 @@ class XlaResource { xla::XlaOp value_; xla::XlaOp initial_value_; + // An xla shape that indicates how this resource variable is represented on + // device. + absl::optional representation_shape_; + int64 max_array_size_ = -1; bool tensor_array_multiple_writes_aggregate_ = false; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 9461343542757c0cec89d6ebbbecf0033c9df431..c5dea5f18030f2d226c86e3408ea85b2b5989728 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -452,11 +452,12 @@ cc_library( ) cc_library( - name = "self_adjoint_eigen", - srcs = ["self_adjoint_eigen.cc"], - hdrs = ["self_adjoint_eigen.h"], + name = "self_adjoint_eig", + srcs = ["self_adjoint_eig.cc"], + hdrs = ["self_adjoint_eig.h"], deps = [ ":arithmetic", + ":comparators", ":constants", ":loops", ":math", @@ -473,8 +474,8 @@ cc_library( ) xla_test( - name = "self_adjoint_eigen_test", - srcs = ["self_adjoint_eigen_test.cc"], + name = "self_adjoint_eig_test", + srcs = ["self_adjoint_eig_test.cc"], blacklisted_backends = [ "cpu", "gpu", @@ -486,7 +487,7 @@ xla_test( ":arithmetic", ":constants", ":matrix", - ":self_adjoint_eigen", + ":self_adjoint_eig", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc similarity index 93% rename from tensorflow/compiler/xla/client/lib/self_adjoint_eigen.cc rename to tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc index 1dc87c29a92faf10c8a9c5df86a26ea46f041d3d..546127e4627f1717913d1039be13fd0c655be1a3 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" #include #include #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/math.h" @@ -341,6 +342,27 @@ StatusOr> WhileLoopFn( return values; } +StatusOr SortByEigenvalues(SelfAdjointEigResult result) { + XlaBuilder* builder = result.v.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.v)); + const int64 num_dims = shape.rank(); + auto dimensions = shape.dimensions(); + + std::vector broadcast_dims(num_dims - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims[num_dims - 2] = num_dims - 1; + result.w = BroadcastInDim(result.w, dimensions, broadcast_dims); + + XlaOp sort_result = + Sort({result.w, result.v}, + CreateScalarLtComputation( + {shape.element_type(), shape.element_type()}, builder), + num_dims - 1); + result.w = GetMatrixDiagonal(GetTupleElement(sort_result, 0)); + result.v = GetTupleElement(sort_result, 1); + return result; +} + } // namespace // This is the cyclic Jacobi iteration. Please note that the eigenvalues are @@ -373,11 +395,11 @@ StatusOr> WhileLoopFn( // // TODO(kuny): Implement parallel order Jacobi. // -SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower, int64 max_iter, - float epsilon) { +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, + float epsilon) { XlaBuilder* builder = a.builder(); auto return_error = [&](const Status& status) { - SelfAdjointEigenResult result; + SelfAdjointEigResult result; result.v = builder->ReportError(status); result.w = builder->ReportError(status); return result; @@ -439,11 +461,11 @@ SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower, int64 max_iter, auto output = output_with_status.ValueOrDie(); - SelfAdjointEigenResult result; + SelfAdjointEigResult result; result.v = output[1]; result.w = GetMatrixDiagonal(output[2]); - return result; + return SortByEigenvalues(result).ValueOrDie(); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h similarity index 71% rename from tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h rename to tensorflow/compiler/xla/client/lib/self_adjoint_eig.h index 49fc17aa275a8e831e800069290db2dd047416e4..2a089891d6a2d80c0c265a3310539b4f1c5db4d5 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -23,20 +23,18 @@ namespace xla { // The eigenvalue decomposition of a symmetric matrix, the original matrix is // recovered by v * w * v_t. -struct SelfAdjointEigenResult { +struct SelfAdjointEigResult { // The i-th column is the normalized eigenvector corresponding to the // eigenvalue w[i]. Will return a matrix object if a is a matrix object. XlaOp v; - // TODO(kuny): Sort the eigenvalues. // The eigenvalues in ascending order, each repeated according to its // multiplicity. XlaOp w; }; -SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower = true, - int64 max_iter = 100, - float epsilon = 1e-6); +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true, + int64 max_iter = 100, float epsilon = 1e-6); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc similarity index 84% rename from tensorflow/compiler/xla/client/lib/self_adjoint_eigen_test.cc rename to tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc index aa8fa816c095a06833d4afb52e8069fffbd74b41..c8875dff7bfdbd4e133297cef0a6686bfcd9bb6f 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen_test.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" @@ -32,7 +32,7 @@ limitations under the License. namespace xla { -class SelfAdjointEigenTest : public ClientLibraryTestBase { +class SelfAdjointEigTest : public ClientLibraryTestBase { protected: void SetUp() override { ClientLibraryTestBase::SetUp(); @@ -71,7 +71,7 @@ class SelfAdjointEigenTest : public ClientLibraryTestBase { } void TearDown() override { ClientLibraryTestBase::TearDown(); } - Array3D get_unit_matrix_3d(const Array3D& matrix) { + Array3D GetUnitMatrix3D(const Array3D& matrix) { Array3D result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0); for (int i = 0; i < matrix.n1(); ++i) { for (int j = 0; j < matrix.n2(); ++j) { @@ -100,7 +100,7 @@ class SelfAdjointEigenTest : public ClientLibraryTestBase { return result; } - XlaOp ComputeMatmulVWVt(SelfAdjointEigenResult result, XlaBuilder* builder) { + XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) { Shape shape = builder->GetShape(result.v).ValueOrDie(); std::vector out_dims = shape.dimensions(); std::vector broadcast_dims(shape.rank() - 1); @@ -140,69 +140,69 @@ class SelfAdjointEigenTest : public ClientLibraryTestBase { Array2D wrong_type_4x4_; }; -XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_2x4x4) { +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR3Parameter(batch_3d_4x4_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); ComputeMatmulVWVt(result, &builder); ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_Lower_2x4x4) { +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR3Parameter( ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); ComputeMatmulVWVt(result, &builder); ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_Upper_2x4x4) { +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR3Parameter( ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a, false); + auto result = SelfAdjointEig(a, false); ComputeMatmulVWVt(result, &builder); ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_2x4x4) { +XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR3Parameter(batch_3d_4x4_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST); - ComputeAndCompareR3(&builder, get_unit_matrix_3d(batch_3d_4x4_), + ComputeAndCompareR3(&builder, GetUnitMatrix3D(batch_3d_4x4_), {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { +XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR2Parameter(low_rank_4x4_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); ComputeMatmulVWVt(result, &builder); ComputeAndCompareR2(&builder, low_rank_4x4_, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_Eigen_8x8) { +XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) { XlaBuilder builder(TestName()); // This is computed by numpy.linalg.eigh with float32. @@ -211,21 +211,21 @@ XLA_TEST_F(SelfAdjointEigenTest, Test_Eigen_8x8) { XlaOp a; auto a_data = CreateR2Parameter(matrix2d_8x8_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); - Sort(result.w); + auto result = SelfAdjointEig(a); + Add(result.w, ZerosLike(result.w)); ComputeAndCompareR1(&builder, expected, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_8x8) { +XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) { XlaBuilder builder(TestName()); float expected_vals = 1e-3; XlaOp a; auto a_data = CreateR2Parameter(matrix2d_8x8_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); // np.sum(norm(eye(n) - matmul(conj(T(v)), v)) / n**2 GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8), BatchDot(TransposeInMinorDims(result.v), result.v), @@ -235,75 +235,75 @@ XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_8x8) { ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Wrong_Type_Int) { +XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR2Parameter(wrong_type_4x4_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); EXPECT_FALSE(result.v.valid()); EXPECT_FALSE(result.w.valid()); } -XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_8x8) { +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) { XlaBuilder builder(TestName()); int size = 8; Array2D a_val = GenerateRandomSymmetricMatrix(size); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_16x16) { +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) { XlaBuilder builder(TestName()); int size = 16; Array2D a_val = GenerateRandomSymmetricMatrix(size); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_32x32) { +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) { XlaBuilder builder(TestName()); int size = 32; Array2D a_val = GenerateRandomSymmetricMatrix(size); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_256x256) { +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) { XlaBuilder builder(TestName()); int size = 256; Array2D a_val = GenerateRandomSymmetricMatrix(size); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_512x512) { +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) { XlaBuilder builder(TestName()); int size = 512; Array2D a_val = GenerateRandomSymmetricMatrix(size); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index 77145ba7d4c72435450d3e33d57b2507eb84d2fc..d7b33c5af25606c4e7e443027b913f7ca13a013c 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -134,4 +134,31 @@ XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, }); } +XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim) { + XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); + ShapeUtil::AppendMajorDimension(1, &index_shape); + std::vector to_concat; + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + to_concat.reserve(input_shape.rank()); + for (int64 i = 0; i < input_shape.rank(); ++i) { + if (i == dim) { + to_concat.push_back(Reshape(index, index_shape.dimensions())); + } else { + to_concat.push_back(Iota(builder, index_shape, i)); + } + } + XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank()); + std::vector slice_sizes(input_shape.rank(), 1); + GatherDimensionNumbers gather_dnums; + gather_dnums.set_index_vector_dim(input_shape.rank()); + for (int64 i = 0; i < input_shape.rank(); ++i) { + gather_dnums.add_collapsed_slice_dims(i); + gather_dnums.add_start_index_map(i); + } + return Gather(input, gather_indices, gather_dnums, slice_sizes); + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h index 6c482a38b5489c9fb17c3dca9ee3d2a1b8fd1890..69f98a6f43fa167adf6f77b28645a3460b292633 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.h +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -43,6 +43,20 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, absl::Span starts); +// Gathers values along an axis specified by dim. +// +// For a 3-D tensor the output is specified by: +// +// out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 +// out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 +// out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 +// +// If `input` is an n-dimensional tensor with size +// [X0,X1,X2,..XN] and dim = i `index` must be an n-dimensional tensor with size +// [X0,X1,...Y,Xi+1,...,X[N] where y >= 1 and `out` will have the same sizes as +// `index`. +XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc index 8d362119e01006555db0f82d02626175936e1d05..db6ebb9df18372260a64a3e9fd17b0c30b35667d 100644 --- a/tensorflow/compiler/xla/client/lib/slicing_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -102,5 +102,18 @@ XLA_TEST_F(SlicingTest, SimpleSliceUpdate) { {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); } +XLA_TEST_F(SlicingTest, TorchGather) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp input, index; + auto input_data = + CreateR2Parameter({{1, 2}, {3, 4}}, 0, "input", &builder, &input); + auto index_data = + CreateR2Parameter({{0, 0}, {1, 0}}, 1, "index", &builder, &index); + TorchGather(input, index, 1); + + ComputeAndCompareR2(&builder, {{1, 1}, {4, 3}}, + {input_data.get(), index_data.get()}); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index 3245f46e6fd6f365f2e4ee90b3c88cf1bd3b5b85..ddc39f4d874cd3613a763b969091e7e65ff1c783 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -36,7 +36,8 @@ XlaOp TopK(XlaOp input, int64 k) { XlaOp sort_result = Sort({Neg(input), iota_s32}, CreateScalarLtComputation({input_shape.element_type(), S32}, - iota_s32.builder())); + iota_s32.builder()), + last_dim, /*is_stable=*/true); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index ae78910a5b416ceba6da9286b42dde9fb4ebced5..0fbd138aca1e86f219d0459086fc09d20844f135 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -81,9 +81,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) { ComputeAndCompareR1(&builder, inputs, {}); } -// TODO(b/122298745): Enable this test when the GPU backend supports stable -// sorting. -XLA_TEST_F(SortingTest, DISABLED_ON_GPU(TopKFullSortWithDuplicates)) { +XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR1Parameter({1, 1, 2, 2, 1}, 0, "a", &builder, &a); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index fb9dbe851e7db8da3a496c40a63b39f80cf1ff33..b371b5af37b3b1bf911133a485554f87c8e09183 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1663,14 +1663,16 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, Lt(first_lhs_param, first_rhs_param); TF_ASSIGN_OR_RETURN(auto comparator, b->Build()); - return Sort(operands, comparator, dimension); + return Sort(operands, comparator, dimension, /*is_stable=*/false); }); } XlaOp XlaBuilder::Sort(absl::Span operands, - const XlaComputation& comparator, int64 dimension) { + const XlaComputation& comparator, int64 dimension, + bool is_stable) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; + instr.set_is_stable(is_stable); std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(std::vector operand_shapes, GetOperandShapes(operands)); @@ -3320,8 +3322,9 @@ XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension) { } XlaOp Sort(absl::Span operands, const XlaComputation& comparator, - int64 dimension) { - return operands[0].builder()->Sort(operands, comparator, dimension); + int64 dimension, bool is_stable) { + return operands[0].builder()->Sort(operands, comparator, dimension, + is_stable); } XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 1e39c8766f318f7f31778265dddae6b2b32e111d..fd2e9816e8a0b755b0a1060e8ed4e30c317bd208 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -505,7 +505,7 @@ class XlaBuilder { XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); XlaOp Sort(absl::Span operands, const XlaComputation& comparator, - int64 dimension = -1); + int64 dimension = -1, bool is_stable = false); XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -923,7 +923,8 @@ class XlaBuilder { friend XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension); friend XlaOp Sort(absl::Span operands, - const XlaComputation& comparator, int64 dimension); + const XlaComputation& comparator, int64 dimension, + bool is_stable); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, @@ -1695,7 +1696,8 @@ XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); // Enqueues a sort instruction onto the computation, using 'comparator' for -// comparisons. 'comparator' needs to define a strict weak order. +// comparisons. 'comparator' needs to define a strict weak order. 'is_stable' +// determines whether the stable sorting should be used. // If only one operand is provided: // * If the operand is a rank-1 tensor (an array), the result is a sorted array. // The resulting sorting order has the property that for all index positions @@ -1718,7 +1720,7 @@ XlaOp Sort(const XlaOp& keys, absl::Span values = {}, // correspond to the value of operand i at two index positions. // Default comparator computations can be found in lib/comparators.h XlaOp Sort(absl::Span operands, const XlaComputation& comparator, - int64 dimension = -1); + int64 dimension = -1, bool is_stable = false); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index a9a91648ac377987e7f226116e11c9c697ace103..43d9ee0d9a5e689676b00e59d7c59bb0f4e37461 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -128,11 +128,6 @@ static void AllocateFlags() { tensorflow::Flag( "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), "With xla_generate_hlo_graph, dump the graphs into this path."), - tensorflow::Flag( - "xla_hlo_dump_as_graphdef", - bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), - flag_values->xla_hlo_dump_as_graphdef(), - "Dump HLO graphs as TensorFlow GraphDefs."), tensorflow::Flag("xla_hlo_dump_as_html", bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_html), flag_values->xla_hlo_dump_as_html(), @@ -144,13 +139,6 @@ static void AllocateFlags() { flag_values->xla_hlo_graph_sharding_color(), "Assign colors based on sharding assignments when generating the " "HLO graphs."), - tensorflow::Flag( - "xla_hlo_tfgraph_device_scopes", - bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes), - flag_values->xla_hlo_tfgraph_device_scopes(), - "When generating TensorFlow HLO graphs, if the HLO instructions " - "are assigned to a specific device, prefix the name scope with " - "\"devX\" with X being the device ordinal."), tensorflow::Flag( "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), "HLO modules matching this regex will be dumped to LOG(INFO)."), diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 0f9b591c70d4fd96147958d18bd5fb7dd78a7f3f..230f3b202a4b531c381665471c3856c3feba5a3a 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -77,7 +77,7 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const { } ExecutableRunOptions& ExecutableRunOptions::set_device_assignment( - DeviceAssignment* device_assignment) { + const DeviceAssignment* device_assignment) { device_assignment_ = device_assignment; return *this; } diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 6f36d11dfb34eb27e79ea4ff797d35f80fb44b27..1e744953bd3be58afba5b81c0e2a8ba26665f9c4 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -74,7 +74,7 @@ class ExecutableRunOptions { ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile); ExecutableRunOptions& set_device_assignment( - DeviceAssignment* device_assignment); + const DeviceAssignment* device_assignment); const DeviceAssignment* device_assignment() const; ExecutableRunOptions& set_rng_seed(int rng_seed); @@ -83,7 +83,7 @@ class ExecutableRunOptions { private: DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; - DeviceAssignment* device_assignment_ = nullptr; + const DeviceAssignment* device_assignment_ = nullptr; stream_executor::Stream* stream_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc index d2f7985aab7123a80ce626b27aa612edda87f761..000c4fdc40519214fa9fa721a8987b77b534442b 100644 --- a/tensorflow/compiler/xla/layout.cc +++ b/tensorflow/compiler/xla/layout.cc @@ -35,11 +35,11 @@ string Tile::ToString() const { if (dim >= 0) { elements.push_back(std::to_string(dim)); } else { - CHECK_EQ(dim, kCombineDimension) - << "Tile dimension size needs to be mininum int64 value if it's " - "negative. Value is " - << dim; - elements.push_back("*"); + if (dim == kCombineDimension) { + elements.push_back("*"); + } else { + elements.push_back(absl::StrCat("Invalid value ", dim)); + } } } return absl::StrCat("(", absl::StrJoin(elements, ","), ")"); @@ -95,12 +95,24 @@ string Layout::ToString() const { } } +bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { + if (lhs.format() != rhs.format() || + lhs.minor_to_major() != rhs.minor_to_major() || + lhs.max_sparse_elements() != rhs.max_sparse_elements()) { + return false; + } + if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) { + return false; + } + if (!ignore_element_size_ && + lhs.element_size_in_bits() != rhs.element_size_in_bits()) { + return false; + } + return true; +} + bool Layout::operator==(const Layout& other) const { - return (other.format() == format() && - other.minor_to_major() == minor_to_major() && - other.element_size_in_bits() == element_size_in_bits() && - other.max_sparse_elements() == max_sparse_elements() && - other.tiles() == tiles()); + return Equal()(*this, other); } std::ostream& operator<<(std::ostream& out, const Tile& tile) { diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index 1faa1629980f5a8d954b70a4d860dbf708de2624..acc449b781b503142b24ed7229e3559230bb1599 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -85,10 +85,12 @@ class Layout { // Constructs a dense tiled layout with the given minor-to-major order and // tiles. - Layout(absl::Span minor_to_major, absl::Span tiles) + Layout(absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits = 0) : format_(DENSE), minor_to_major_(minor_to_major.begin(), minor_to_major.end()), - tiles_(tiles.begin(), tiles.end()) {} + tiles_(tiles.begin(), tiles.end()), + element_size_in_bits_(element_size_in_bits) {} // Construct a shape from a LayoutProto. static Layout CreateFromProto(const LayoutProto& proto); @@ -99,6 +101,37 @@ class Layout { // Returns a human-readable string that represents this layout. string ToString() const; + // Equal is a configurable functor to check the equality of two layouts. + // + // Examples: + // + // - Comparing two layouts ignoring their difference in tiles: + // Equal().IgnoreTiles()(layout1, layout2); + // + // - Comparing two layouts ignoring their difference in tiles and element + // size: + // Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2); + class Equal { + public: + Equal() = default; + + bool operator()(const Layout& lhs, const Layout& rhs); + + Equal& IgnoreTiles() { + ignore_tiles_ = true; + return *this; + } + + Equal& IgnoreElementSize() { + ignore_element_size_ = true; + return *this; + } + + private: + bool ignore_tiles_ = false; + bool ignore_element_size_ = false; + }; + bool operator==(const Layout& other) const; bool operator!=(const Layout& other) const { return !(*this == other); } @@ -173,7 +206,7 @@ class Layout { element_size_in_bits_ = 0; } - public: + private: // The format of this layout. Format format_ = INVALID_FORMAT; @@ -186,11 +219,11 @@ class Layout { // memory. This field must be zero unless the format is SPARSE. int64 max_sparse_elements_ = 0; - // The number of bits used to store an individual array element. - int64 element_size_in_bits_ = 0; - // The tiles used in tiling-based layout. std::vector tiles_; + + // The number of bits used to store an individual array element. + int64 element_size_in_bits_ = 0; }; std::ostream& operator<<(std::ostream& out, const Tile& Tile); diff --git a/tensorflow/compiler/xla/layout_test.cc b/tensorflow/compiler/xla/layout_test.cc index 7d43b0b87c8eeabf1d30187625a67967a11c3eb4..f5d71c553ed2e0cfd5d5945144dd476557582b5f 100644 --- a/tensorflow/compiler/xla/layout_test.cc +++ b/tensorflow/compiler/xla/layout_test.cc @@ -42,6 +42,9 @@ TEST_F(LayoutTest, ToString) { EXPECT_EQ( Layout({1, 0}, {Tile({2, 55})}).set_element_size_in_bits(42).ToString(), "{1,0:T(2,55)E(42)}"); + EXPECT_EQ( + Layout({1, 0}, {Tile({-2, 55})}).set_element_size_in_bits(42).ToString(), + "{1,0:T(Invalid value -2,55)E(42)}"); } TEST_F(LayoutTest, StreamOut) { @@ -84,6 +87,15 @@ TEST_F(LayoutTest, Equality) { Layout().set_format(SPARSE).set_max_sparse_elements(42)); EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42), Layout().set_format(SPARSE).set_max_sparse_elements(24)); + + EXPECT_FALSE( + Layout::Equal()(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2}))); + EXPECT_TRUE(Layout::Equal().IgnoreTiles()(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}))); + EXPECT_FALSE( + Layout::Equal()(Layout({0, 1, 2}, {}, 32), Layout({0, 1, 2}, {}, 1))); + EXPECT_TRUE(Layout::Equal().IgnoreElementSize()(Layout({0, 1, 2}, {}, 32), + Layout({0, 1, 2}, {}, 1))); } TEST_F(LayoutTest, LayoutToFromProto) { diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index f7e2d26b7aa7f23f5e0a2e7623863402ef549789..55eacc1c16a76522215d27ac7cf4e801e69c9740 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -59,10 +59,6 @@ cc_library( srcs = ["local_computation_builder.cc"], hdrs = ["local_computation_builder.h"], deps = [ - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -77,14 +73,38 @@ cc_library( "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/core:lib", + "//third_party/python_runtime:headers", # buildcleaner: keep + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "xrt", + srcs = ["xrt.cc"], + hdrs = ["xrt.h"], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt/cc:xrt_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -92,9 +112,12 @@ cc_library( tf_py_wrap_cc( name = "pywrap_xla", - srcs = ["xla.i"], + srcs = [ + "xla.i", + ], swig_includes = [ "local_computation_builder.i", + "xla_data.i", "//tensorflow/python:platform/base.i", ], version_script = select({ @@ -111,3 +134,27 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla/service:cpu_plugin", ] + xla_python_default_plugins(), ) + +tf_py_wrap_cc( + name = "pywrap_xrt", + srcs = [ + "xrt.i", + ], + swig_includes = [ + "xla_data.i", + "//tensorflow/python:platform/base.i", + ], + version_script = select({ + "//tensorflow:darwin": "pywrap_xla_exported_symbols.lds", + "//tensorflow:windows": None, + "//conditions:default": "pywrap_xla_version_script.lds", + }), + visibility = ["//visibility:public"], + deps = [ + ":numpy_bridge", + ":xrt", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + ], +) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 10d03e9f57fedc464bf0bc922b2eabb7208b8267..a4e5bdb39c227fc2b0294061108e0f44c1b33db4 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -20,10 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/cc/client/client_session.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/cholesky.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/qr.h" @@ -32,16 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -50,77 +44,6 @@ limitations under the License. namespace xla { namespace swig { -// TODO(b/118641336): Factor out XRT parts into a small c++ library of their -// own. - -// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of -// device handles instead of needing to set the number of replicas at XLA -// service initialization time. -tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED); -int g_replica_count GUARDED_BY(g_local_client_mutex) = 1; -LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr; - -string* GetPlatformNameString() { - static string* platform_name_string PT_GUARDED_BY(g_local_client_mutex) = - new string("Host"); - return platform_name_string; -} - -Status InitializeReplicaCount(int replica_count) { - if (replica_count < 1) { - return InvalidArgument("Replica count must be >= 1; got %d.", - replica_count); - } - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return FailedPrecondition( - "Attempted to set the replica count to %d, but a local XLA service was " - "previously created with a replica count of %d.", - replica_count, g_replica_count); - } - g_replica_count = replica_count; - return Status::OK(); -} - -Status InitializePlatformName(const string& platform_name) { - string* g_platform_name = GetPlatformNameString(); - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return FailedPrecondition( - "Attempted to set the platform name to %s, but a local XLA service was " - "previously created with a platform name of %s.", - platform_name, *g_platform_name); - } - TF_ASSIGN_OR_RETURN(se::Platform * platform, - PlatformUtil::GetPlatform(platform_name)); - if (platform->VisibleDeviceCount() <= 0) { - return InvalidArgument("Platform %s has no visible devices.", - platform_name); - } - *g_platform_name = platform_name; - return Status::OK(); -} - -int GetReplicaCount() { - tensorflow::mutex_lock lock(g_local_client_mutex); - return g_replica_count; -} - -StatusOr GetOrCreateLocalClient() { - string* platform_name = GetPlatformNameString(); - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return g_local_client; - } - LocalClientOptions options; - options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie()); - options.set_number_of_replicas(g_replica_count); - TF_ASSIGN_OR_RETURN(g_local_client, - ClientLibrary::GetOrCreateLocalClient(options)); - CHECK(g_local_client != nullptr); - return g_local_client; -} - Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { const char* name = "xla._CPU_CUSTOM_CALL_TARGET"; if (!PyCapsule_IsValid(capsule, name)) { @@ -135,62 +58,66 @@ Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { return Status::OK(); } -Status TransferToInfeedLocal(const Literal& literal) { - VLOG(1) << "Infeeding literal without replica number; shape: " - << literal.shape(); - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0); -} +LocalClient::LocalClient(xla::LocalClient* client) : client_(client) {} -Status TransferToInfeedLocalReplica(const Literal& literal, - int replica_number) { - VLOG(1) << "Infeeding shape " << literal.shape() - << " to replica number: " << replica_number; - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - return client->TransferToInfeedLocal(literal, device_ordinal); +/* static */ StatusOr LocalClient::Get( + const string& platform_name) { + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(platform_name)); + if (platform->VisibleDeviceCount() <= 0) { + return InvalidArgument("Platform %s has no visible devices.", + platform_name); + } + LocalClientOptions options; + options.set_platform(platform); + TF_ASSIGN_OR_RETURN(xla::LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(options)); + CHECK(client != nullptr); + return LocalClient(client); } -StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, - int replica_number) { - VLOG(1) << "Outfeeding literal from replica number: " << replica_number - << " shape: " << shape; - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - return client->TransferFromOutfeedLocal(shape, device_ordinal); +// Returns the number of devices known to the XLA client. +int LocalClient::DeviceCount() const { return client_->device_count(); } + +Status LocalClient::TransferToInfeed(const Literal& literal, + int device_ordinal) { + VLOG(1) << "Infeeding literal to device " << device_ordinal + << "; shape: " << literal.shape(); + return client_->TransferToInfeed(literal, device_ordinal); } -static StatusOr ToBuffer(LocalClient* client, - int device_ordinal, - const Literal& arg) { - return client->LiteralToShapedBuffer(arg, device_ordinal, - client->backend().memory_allocator()); +StatusOr LocalClient::TransferFromOutfeed(const Shape& shape, + int device_ordinal) { + VLOG(1) << "Outfeeding literal from device " << device_ordinal + << "; shape: " << shape; + return client_->TransferFromOutfeed(&shape, device_ordinal); } /* static */ StatusOr LocalShapedBuffer::FromLiteral( const Literal& argument, const absl::optional& shape_with_layout, - int replica_number) { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " - << replica_number << "/" << device_ordinal; + const LocalClient& client, int device_ordinal) { + VLOG(1) << "Creating shaped buffer from literal on device ordinal: " + << device_ordinal; + auto literal_to_buffer = [&](const Literal& arg) { + return client.client()->LiteralToShapedBuffer( + arg, device_ordinal, client.client()->backend().memory_allocator()); + }; + StatusOr buf = [&] { if (shape_with_layout) { Literal relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, device_ordinal, relaid); + return literal_to_buffer(relaid); } - return ToBuffer(client, device_ordinal, argument); + return literal_to_buffer(argument); }(); TF_RETURN_IF_ERROR(buf.status()); - return new LocalShapedBuffer(std::move(buf).ValueOrDie()); + return new LocalShapedBuffer(std::move(buf).ValueOrDie(), client.client()); } -LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer) - : shaped_buffer_(std::move(shaped_buffer)) {} +LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, + xla::LocalClient* client) + : shaped_buffer_(std::move(shaped_buffer)), client_(client) {} const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { return &shaped_buffer_; @@ -203,8 +130,7 @@ const Shape& LocalShapedBuffer::shape() const { } StatusOr LocalShapedBuffer::ToLiteral() const { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - return client->ShapedBufferToLiteral(*shaped_buffer()); + return client_->ShapedBufferToLiteral(*shaped_buffer()); } LocalShapedBufferTuple::LocalShapedBufferTuple( @@ -235,120 +161,77 @@ StatusOr LocalShapedBufferTuple::Release(int i) { int64 LocalShapedBufferTuple::size() const { return elements_.size(); } -XrtAllocation::XrtAllocation(int64 handle, Shape shape, - const string& session_target) - : handle_(handle), shape_(shape), session_target_(session_target) {} - -XrtAllocation::~XrtAllocation() { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto allocation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto release = - tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle); - if (!root.status().ok()) { - LOG(ERROR) << root.status(); - return; - } +StatusOr LocalShapedBuffer::DestructureTuple() { + const Shape tuple_shape = shape(); - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({allocation_handle, handle()}); - std::vector outputs; - auto status = session.Run(inputs, {}, {release}, &outputs); - if (!status.ok()) { - LOG(ERROR) << status; - return; + if (!tuple_shape.IsTuple()) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString(tuple_shape)); } -} - -/* static */ -StatusOr XrtAllocation::FromLiteral( - const Literal& argument, const string& session_target) { - xrt::XLAAllocation alloc; - *alloc.mutable_value() = argument.ToProto(); - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto literal_string = - tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string); - TF_RETURN_IF_ERROR(root.status()); - - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({literal_string, alloc.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs)); - - int64 handle = outputs[0].scalar()(); - return new XrtAllocation(handle, argument.shape(), session_target); -} - -const int64 XrtAllocation::handle() const { return handle_; } -const Shape& XrtAllocation::shape() const { return shape_; } + DeviceMemoryAllocator* allocator = shaped_buffer()->memory_allocator(); + ShapedBuffer tuple_buffer = Release(); -StatusOr XrtAllocation::ToLiteral() const { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto allocation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle); - TF_RETURN_IF_ERROR(root.status()); + // Extract some metadata we use to construct scoped buffers. + const se::Platform* platform = tuple_buffer.platform(); + int device_ordinal = tuple_buffer.device_ordinal(); - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({allocation_handle, handle()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs)); + ShapeTree& shape_tree = tuple_buffer.buffers(); + std::vector results; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + // Create a shaped buffer for this destructured tuple element. + const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); + VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; + ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); - xla::LiteralProto response; - TF_RET_CHECK(response.ParseFromString(outputs[0].scalar()())); - return Literal::CreateFromProto(response); -} + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& index) { + ShapeIndex original(index); + original.push_front(i); + se::DeviceMemoryBase* device_memory = + shape_tree.mutable_element(original); + shaped_buffer.set_buffer(*device_memory, index); + *device_memory = se::DeviceMemoryBase(); + }); -XrtAllocationTuple::XrtAllocationTuple(std::vector elements) - : elements_(std::move(elements)) { - for (auto* element : elements_) { - CHECK(element != nullptr); + VLOG(3) << "Completed tuple element: " << i; + results.push_back(new LocalShapedBuffer( + ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_)); } + // Deallocate the root buffer. + se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); + TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); + return new LocalShapedBufferTuple(std::move(results)); } -XrtAllocationTuple::~XrtAllocationTuple() { - for (XrtAllocation* element : elements_) { - if (element != nullptr) { - delete element; - } - } -} +LocalExecutable::LocalExecutable( + std::unique_ptr executable, + xla::DeviceAssignment device_assignment, xla::LocalClient* client) + : executable_(std::move(executable)), + device_assignment_(std::move(device_assignment)), + client_(client) {} -StatusOr XrtAllocationTuple::Release(int i) { - XrtAllocation* element = elements_[i]; - if (element == nullptr) { - return InvalidArgument("Attempted to release already-released element %d.", - i); +std::vector LocalExecutable::DeviceOrdinals() const { + int num_replicas = device_assignment_.replica_count(); + std::vector device_ordinals; + device_ordinals.reserve(num_replicas); + for (int i = 0; i < num_replicas; ++i) { + device_ordinals.push_back(device_assignment_(i, 0)); } - elements_[i] = nullptr; - return element; + return device_ordinals; } -int64 XrtAllocationTuple::size() const { return elements_.size(); } - -CompiledLocalComputation::CompiledLocalComputation( - std::unique_ptr executable) - : executable_(std::move(executable)) {} - -StatusOr CompiledLocalComputation::Execute( +StatusOr LocalExecutable::Execute( absl::Span argument_handles) { if (num_replicas() != 1) { return InvalidArgument( "Attempted to execute computation with %d replicas using Execute()", num_replicas()); } - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - client->backend().computation_placer()->AssignDevices( - 1, /*computation_count=*/1)); StatusOr result_buffer_status; - const int device_ordinal = device_assignment(0, 0); + const int device_ordinal = device_assignment_(0, 0); VLOG(3) << "Replica 0 mapped to device ordinal for execution: " << device_ordinal; @@ -360,10 +243,10 @@ StatusOr CompiledLocalComputation::Execute( ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); + options.set_allocator(client_->backend().memory_allocator()); options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); + client_->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment_); result_buffer_status = executable_->Run(argument_buffers, options); @@ -373,13 +256,13 @@ StatusOr CompiledLocalComputation::Execute( "%s.", result_buffer_status.status().ToString()); } - return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); + return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(), + client_); } -StatusOr CompiledLocalComputation::ExecutePerReplica( +StatusOr LocalExecutable::ExecutePerReplica( absl::Span> argument_handles) { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - const int num_devices = client->device_count(); + const int num_devices = client_->device_count(); if (argument_handles.size() != num_replicas()) { return InvalidArgument( @@ -394,14 +277,9 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( VLOG(1) << "Executing with " << num_replicas() << " replicas."; - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - client->backend().computation_placer()->AssignDevices( - num_replicas(), /*computation_count=*/1)); - std::vector> results(num_replicas()); - auto execute = [this, client, &device_assignment, &argument_handles, - &results](int replica) { - const int device_ordinal = device_assignment(replica, 0); + auto execute = [this, &argument_handles, &results](int replica) { + const int device_ordinal = device_assignment_(replica, 0); VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; @@ -413,10 +291,10 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); + options.set_allocator(client_->backend().memory_allocator()); options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); + client_->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment_); StatusOr result_buffer_status = executable_->Run(argument_buffers, options); @@ -448,145 +326,43 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( replica, statusor.status().ToString()); } wrapped_results[replica] = - new LocalShapedBuffer(std::move(statusor).ValueOrDie()); + new LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_); } return new LocalShapedBufferTuple(std::move(wrapped_results)); } -CompiledXrtComputation::CompiledXrtComputation( - const ProgramShape& program_shape, int64 handle, - const string& session_target) - : program_shape_(program_shape), - handle_(handle), - session_target_(session_target) {} - -CompiledXrtComputation::~CompiledXrtComputation() { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto computation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto release = - tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle); - if (!root.status().ok()) { - LOG(ERROR) << root.status(); - return; - } - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({computation_handle, handle()}); - std::vector outputs; - auto status = session.Run(inputs, {}, {release}, &outputs); - if (!status.ok()) { - LOG(ERROR) << status; - return; - } -} - -StatusOr CompiledXrtComputation::Execute( - absl::Span argument_handles) { - const int num_expected_arguments = program_shape().parameters().size(); - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - std::vector arguments; - arguments.reserve(num_expected_arguments); - for (int i = 0; i < num_expected_arguments; ++i) { - arguments.push_back( - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64)); - } - auto computation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto execution_config = - tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto execute = tensorflow::ops::XRTExecute(root, computation_handle, - execution_config, arguments); - TF_RETURN_IF_ERROR(root.status()); - - TF_RET_CHECK(argument_handles.size() == arguments.size()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(false); - e.set_release_compilation_handle(false); - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - for (int i = 0; i < arguments.size(); ++i) { - inputs.insert({arguments[i], argument_handles[i]->handle()}); - } - inputs.insert({computation_handle, handle()}); - inputs.insert({execution_config, e.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs)); - - int64 output = outputs[0].scalar()(); - return new XrtAllocation(output, program_shape().result(), session_target_); -} - -const ProgramShape& CompiledXrtComputation::program_shape() const { - return program_shape_; -} - -int64 CompiledXrtComputation::handle() const { return handle_; } - -LocalComputation::LocalComputation(XlaComputation computation) +Computation::Computation(XlaComputation computation) : computation_(std::move(computation)) {} -StatusOr LocalComputation::Compile( +StatusOr Computation::Compile( const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options) { + const ExecutableBuildOptions* build_options, const LocalClient& client) { std::vector argument_shape_pointers; argument_shape_pointers.reserve(argument_shapes.size()); for (auto& argument_shape : argument_shapes) { argument_shape_pointers.push_back(&argument_shape); } - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); ExecutableBuildOptions options; if (build_options != nullptr) { options = *build_options; } TF_ASSIGN_OR_RETURN( auto local_executable, - client->Compile(computation_, argument_shape_pointers, options)); - return new CompiledLocalComputation(std::move(local_executable)); -} - -StatusOr LocalComputation::CompileForXrt( - const std::vector& argument_shapes, const string& session_target) { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto compile = tensorflow::ops::XRTCompile(root, program); - TF_RETURN_IF_ERROR(root.status()); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - ProgramShape shapes; - for (auto& shape : argument_shapes) { - *shapes.add_parameters() = shape; - } - TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape()); - LayoutUtil::SetToDefaultLayout(&shapes); - *config->mutable_program_shape() = shapes.ToProto(); - auto snapshot = computation().Snapshot().ValueOrDie(); - *c.mutable_hlo_snapshot() = *snapshot; - - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({program, c.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs)); + client.client()->Compile(computation_, argument_shape_pointers, options)); + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + client.client()->backend().computation_placer()->AssignDevices( + options.num_replicas(), /*computation_count=*/1)); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - computation().GetProgramShape()); - int64 handle = outputs[0].scalar()(); - return new CompiledXrtComputation(program_shape, handle, session_target); + return new LocalExecutable(std::move(local_executable), + std::move(device_assignment), client.client()); } -const XlaComputation& LocalComputation::computation() const { - return computation_; -} +const XlaComputation& Computation::computation() const { return computation_; } -string LocalComputation::GetSerializedProto() const { +string Computation::GetSerializedProto() const { string result; if (!computation_.proto().SerializeToString(&result)) { LOG(ERROR) << "Failed to serialize the HloModuleProto."; @@ -595,11 +371,37 @@ string LocalComputation::GetSerializedProto() const { return result; } -StatusOr LocalComputation::GetProgramShape() const { +StatusOr Computation::GetHloText() const { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation_.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation_.proto(), module_config)); + HloPrintOptions options; + options = HloPrintOptions::ShortParsable(); + options.set_print_large_constants(false); + return hlo_module->ToString(options); +} + +StatusOr Computation::GetHloDotGraph() const { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation_.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation_.proto(), module_config)); + hlo_graph_dumper::DotGraphOptions options; + options.debug_options = &hlo_module->config().debug_options(); + return hlo_graph_dumper::HloComputationToDotGraph( + *hlo_module->entry_computation(), options); +} + +StatusOr Computation::GetProgramShape() const { return computation_.GetProgramShape(); } -StatusOr LocalComputation::GetReturnValueShape() const { +StatusOr Computation::GetReturnValueShape() const { TF_ASSIGN_OR_RETURN(ProgramShape shape, computation_.GetProgramShape()); return std::move(*shape.mutable_result()); } @@ -608,93 +410,90 @@ LocalOp::LocalOp(const XlaOp& op) : op_(op) {} const XlaOp& LocalOp::op() const { return op_; } -LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) +ComputationBuilder::ComputationBuilder(const string& computation_name) : builder_(computation_name) {} -void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { +void ComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); } -void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } +void ComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } -StatusOr LocalComputationBuilder::Build() { +StatusOr ComputationBuilder::Build() { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp ComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, const string& name) { return xla::Parameter(&builder_, parameter_number, shape, name); } -StatusOr LocalComputationBuilder::BuildWithRoot( - const LocalOp& root) { +StatusOr ComputationBuilder::BuildWithRoot(const LocalOp& root) { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op())); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -StatusOr LocalComputationBuilder::GetShape(const LocalOp& operand) { +StatusOr ComputationBuilder::GetShape(const LocalOp& operand) { return builder_.GetShape(operand.op()); } -StatusOr LocalComputationBuilder::GetReturnValueShape() { +StatusOr ComputationBuilder::GetReturnValueShape() { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape()); return program_shape.result(); } -LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp ComputationBuilder::Infeed(const Shape& shape) { return xla::Infeed(&builder_, shape); } -void LocalComputationBuilder::Outfeed(const LocalOp& operand, - const Shape& shape, - const string& outfeed_config) { +void ComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, + const string& outfeed_config) { xla::Outfeed(operand.op(), shape, outfeed_config); } -LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { +LocalOp ComputationBuilder::ConstantLiteral(const Literal& literal) { return xla::ConstantLiteral(&builder_, literal); } -LocalOp LocalComputationBuilder::Iota(PrimitiveType element_type, int64 size) { +LocalOp ComputationBuilder::Iota(PrimitiveType element_type, int64 size) { return xla::Iota(&builder_, element_type, size); } -LocalOp LocalComputationBuilder::BroadcastedIota(const Shape& shape, - int64 dimension) { +LocalOp ComputationBuilder::BroadcastedIota(const Shape& shape, + int64 dimension) { return xla::Iota(&builder_, shape, dimension); } -LocalOp LocalComputationBuilder::Broadcast( - const LocalOp& operand, absl::Span broadcast_sizes) { +LocalOp ComputationBuilder::Broadcast(const LocalOp& operand, + absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); } -LocalOp LocalComputationBuilder::BroadcastInDim( +LocalOp ComputationBuilder::BroadcastInDim( const LocalOp& operand, absl::Span out_dim_sizes, absl::Span broadcast_dimensions) { return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions); } -LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, - const LocalOp& padding_value, - const PaddingConfig& padding_config) { +LocalOp ComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { return xla::Pad(operand.op(), padding_value.op(), padding_config); } -LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand, - absl::Span dimensions, - absl::Span new_sizes) { +LocalOp ComputationBuilder::Reshape(const LocalOp& operand, + absl::Span dimensions, + absl::Span new_sizes) { return xla::Reshape(operand.op(), dimensions, new_sizes); } -LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, - absl::Span dimensions) { +LocalOp ComputationBuilder::Collapse(const LocalOp& operand, + absl::Span dimensions) { return xla::Collapse(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::AllToAll( +LocalOp ComputationBuilder::AllToAll( const LocalOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, absl::Span replica_groups) { std::vector rg(replica_groups.size()); @@ -705,39 +504,38 @@ LocalOp LocalComputationBuilder::AllToAll( split_count, rg); } -LocalOp LocalComputationBuilder::CrossReplicaSum( +LocalOp ComputationBuilder::CrossReplicaSum( const LocalOp& operand, absl::Span replica_groups) { return xla::CrossReplicaSum(operand.op(), replica_groups); } -LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, - absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides) { +LocalOp ComputationBuilder::Slice(const LocalOp& operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return xla::Slice(operand.op(), start_indices, limit_indices, strides); } -LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, - int64 start_index, - int64 limit_index, int64 stride, - int64 dimno) { +LocalOp ComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, int64 limit_index, + int64 stride, int64 dimno) { return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno); } -LocalOp LocalComputationBuilder::DynamicSlice( - const LocalOp& operand, const LocalOp& start_indices, - absl::Span slice_sizes) { +LocalOp ComputationBuilder::DynamicSlice(const LocalOp& operand, + const LocalOp& start_indices, + absl::Span slice_sizes) { return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -LocalOp LocalComputationBuilder::DynamicUpdateSlice( - const LocalOp& operand, const LocalOp& update, - const LocalOp& start_indices) { +LocalOp ComputationBuilder::DynamicUpdateSlice(const LocalOp& operand, + const LocalOp& update, + const LocalOp& start_indices) { return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } -LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, - int64 dimension) { +LocalOp ComputationBuilder::ConcatInDim(absl::Span operands, + int64 dimension) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -746,18 +544,18 @@ LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, return xla::ConcatInDim(&builder_, xla_ops, dimension); } -LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const LocalComputation& select, +LocalOp ComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const Computation& select, absl::Span window_dimensions, absl::Span window_strides, absl::Span> padding, const LocalOp& source, - const LocalOp& init_value, const LocalComputation& scatter) { + const LocalOp& init_value, const Computation& scatter) { return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } -LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { +LocalOp ComputationBuilder::Tuple(absl::Span elements) { std::vector xla_ops; xla_ops.reserve(elements.size()); for (const auto& op : elements) { @@ -767,22 +565,22 @@ LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { return xla::Tuple(&builder_, xla_ops); } -LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, - int64 index) { +LocalOp ComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { return xla::GetTupleElement(tuple_data.op(), index); } -LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { +LocalOp ComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { return xla::Dot(lhs.op(), rhs.op()); } -LocalOp LocalComputationBuilder::DotGeneral( +LocalOp ComputationBuilder::DotGeneral( const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -LocalOp LocalComputationBuilder::ConvGeneralDilated( +LocalOp ComputationBuilder::ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -794,18 +592,18 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated( feature_group_count); } -LocalOp LocalComputationBuilder::ConvertElementType( - const LocalOp& operand, PrimitiveType new_element_type) { +LocalOp ComputationBuilder::ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type) { return xla::ConvertElementType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::BitcastConvertType( - const LocalOp& operand, PrimitiveType new_element_type) { +LocalOp ComputationBuilder::BitcastConvertType(const LocalOp& operand, + PrimitiveType new_element_type) { return xla::BitcastConvertType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, - absl::Span operands) { +LocalOp ComputationBuilder::Call(const Computation& local_computation, + absl::Span operands) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -814,7 +612,7 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, return xla::Call(&builder_, local_computation.computation(), xla_ops); } -LocalOp LocalComputationBuilder::CustomCall( +LocalOp ComputationBuilder::CustomCall( const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const std::vector& operand_shapes_with_layout, @@ -829,19 +627,19 @@ LocalOp LocalComputationBuilder::CustomCall( operand_shapes_with_layout, opaque); } -LocalOp LocalComputationBuilder::Transpose( - const LocalOp& operand, absl::Span permutation) { +LocalOp ComputationBuilder::Transpose(const LocalOp& operand, + absl::Span permutation) { return xla::Transpose(operand.op(), permutation); } -LocalOp LocalComputationBuilder::Rev(const LocalOp& operand, - absl::Span dimensions) { +LocalOp ComputationBuilder::Rev(const LocalOp& operand, + absl::Span dimensions) { return xla::Rev(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::Map(absl::Span operands, - const LocalComputation& local_computation, - absl::Span dimensions) { +LocalOp ComputationBuilder::Map(absl::Span operands, + const Computation& local_computation, + absl::Span dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -852,17 +650,17 @@ LocalOp LocalComputationBuilder::Map(absl::Span operands, dimensions); } -LocalOp LocalComputationBuilder::Reduce( +LocalOp ComputationBuilder::Reduce( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions_to_reduce) { return xla::Reduce(operand.op(), init_value.op(), local_computation.computation(), dimensions_to_reduce); } -LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( +LocalOp ComputationBuilder::ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, @@ -874,51 +672,50 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( padding); } -LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, - const LocalOp& sigma, - const Shape& shape) { +LocalOp ComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape) { return xla::RngNormal(mu.op(), sigma.op(), shape); } -LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, - const Shape& shape) { +LocalOp ComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { return xla::RngUniform(a.op(), b.op(), shape); } -LocalOp LocalComputationBuilder::While(const LocalComputation& condition, - const LocalComputation& body, - const LocalOp& init) { +LocalOp ComputationBuilder::While(const Computation& condition, + const Computation& body, + const LocalOp& init) { return xla::While(condition.computation(), body.computation(), init.op()); } -LocalOp LocalComputationBuilder::Conditional( - const LocalOp& predicate, const LocalOp& true_operand, - const LocalComputation& true_computation, const LocalOp& false_operand, - const LocalComputation& false_computation) { +LocalOp ComputationBuilder::Conditional(const LocalOp& predicate, + const LocalOp& true_operand, + const Computation& true_computation, + const LocalOp& false_operand, + const Computation& false_computation) { return xla::Conditional(predicate.op(), true_operand.op(), true_computation.computation(), false_operand.op(), false_computation.computation()); } -StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { +StatusOr ComputationBuilder::IsConstant(const LocalOp& operand) { return builder_.IsConstant(operand.op()); } -LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { +LocalOp ComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { return xla::Sort(operand.op(), {}, dimension); } -LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, - const LocalOp& values, - int64 dimension) { +LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys, + const LocalOp& values, int64 dimension) { return xla::Sort(keys.op(), {values.op()}, dimension); } -LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) { +LocalOp ComputationBuilder::Cholesky(const LocalOp& a) { return xla::Cholesky(a.op()); } -LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { +LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) { XlaBuilder* builder = a.op().builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices)); @@ -926,17 +723,16 @@ LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { }); } -LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a, - const LocalOp& b, - bool left_side, bool lower, - bool unit_diagonal, - int transpose_a) { +LocalOp ComputationBuilder::TriangularSolve(const LocalOp& a, const LocalOp& b, + bool left_side, bool lower, + bool unit_diagonal, + int transpose_a) { return xla::TriangularSolve( a.op(), b.op(), left_side, lower, unit_diagonal, xla::TriangularSolveOptions::Transpose(transpose_a)); } -LocalOp LocalComputationBuilder::Gather( +LocalOp ComputationBuilder::Gather( const LocalOp& input, const LocalOp& start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes) { @@ -944,24 +740,24 @@ LocalOp LocalComputationBuilder::Gather( slice_sizes); } -LocalOp LocalComputationBuilder::Scatter( +LocalOp ComputationBuilder::Scatter( const LocalOp& input, const LocalOp& scatter_indices, - const LocalOp& updates, const LocalComputation& update_computation, + const LocalOp& updates, const Computation& update_computation, const ScatterDimensionNumbers& dimension_numbers) { return xla::Scatter(input.op(), scatter_indices.op(), updates.op(), update_computation.computation(), dimension_numbers); } -StatusOr LocalComputationBuilder::BuildConstantSubGraph( +StatusOr ComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.BuildConstantSubGraph(operand.op())); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -#define _FORWARD(method_name, return_sig, args_sig, args) \ - return_sig LocalComputationBuilder::method_name args_sig { \ - return xla::method_name args; \ +#define _FORWARD(method_name, return_sig, args_sig, args) \ + return_sig ComputationBuilder::method_name args_sig { \ + return xla::method_name args; \ } #define _FORWARD_UNOP(method_name) \ @@ -1048,108 +844,9 @@ void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) { delete local_shaped_buffer; } -void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } - -void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) { - delete computation; -} - -void DeleteCompiledXrtComputation(CompiledXrtComputation* computation) { - delete computation; -} - -void DeleteLocalComputation(LocalComputation* computation) { - delete computation; -} - -StatusOr DestructureLocalShapedBufferTuple( - LocalShapedBuffer* local_shaped_buffer) { - const Shape tuple_shape = local_shaped_buffer->shape(); - - if (!tuple_shape.IsTuple()) { - return InvalidArgument( - "Attemped to destructure a LocalShapedBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(tuple_shape)); - } - - DeviceMemoryAllocator* allocator = - local_shaped_buffer->shaped_buffer()->memory_allocator(); - ShapedBuffer tuple_buffer = local_shaped_buffer->Release(); - - // Extract some metadata we use to construct scoped buffers. - const se::Platform* platform = tuple_buffer.platform(); - int device_ordinal = tuple_buffer.device_ordinal(); - - ShapeTree& shape_tree = tuple_buffer.buffers(); - std::vector results; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { - // Create a shaped buffer for this destructured tuple element. - const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); - VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; - ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); - - ShapeUtil::ForEachSubshape( - subshape, [&](const Shape& s, const ShapeIndex& index) { - ShapeIndex original(index); - original.push_front(i); - se::DeviceMemoryBase* device_memory = - shape_tree.mutable_element(original); - shaped_buffer.set_buffer(*device_memory, index); - *device_memory = se::DeviceMemoryBase(); - }); - - VLOG(3) << "Completed tuple element: " << i; - results.push_back(new LocalShapedBuffer( - ScopedShapedBuffer(std::move(shaped_buffer), allocator))); - } - // Deallocate the root buffer. - se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); - TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); - return new LocalShapedBufferTuple(std::move(results)); -} - -StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation, const string& session_target) { - const Shape& tuple_shape = allocation->shape(); +void DeleteLocalExecutable(LocalExecutable* computation) { delete computation; } - if (!tuple_shape.IsTuple()) { - return InvalidArgument( - "Attemped to destructure a LocalShapedBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(tuple_shape)); - } - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32); - auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index); - TF_RETURN_IF_ERROR(root.status()); - - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - std::vector results; - for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { - inputs.clear(); - inputs.insert({base_handle, allocation->handle()}); - inputs.insert({shape_index, {i}}); - std::vector outputs; - auto status = session.Run(inputs, {subtuple}, &outputs); - if (!status.ok()) { - // Clean up before returning non-ok status. - for (int j = 0; j < results.size(); ++j) { - delete results[j]; - } - return status; - } - const int64 subtuple_handle = outputs[0].scalar()(); - const Shape& subtuple_shape = - ShapeUtil::GetTupleElementShape(tuple_shape, i); - results.push_back( - new XrtAllocation(subtuple_handle, subtuple_shape, session_target)); - } - return new XrtAllocationTuple(std::move(results)); -} +void DeleteComputation(Computation* computation) { delete computation; } } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index f62b2b6c723981a1e5c94acdff621485e4c6ca93..74996d2e6b6101e8accf592f94bf6d9c95685f10 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -22,9 +22,6 @@ limitations under the License. #include #include "absl/types/span.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -35,42 +32,42 @@ limitations under the License. namespace xla { namespace swig { -// Initializes the number of replicas that XLA will be initialized with (when -// first obtaining a handle to the local XLA service). If this is called after -// the handle to the local XLA service has been established, then an error is -// returned. -Status InitializeReplicaCount(int replica_count); - -// Initializes the platform name that XLA will be initialized with (when -// first obtaining a handle to the local XLA service). If this is called after -// the handle to the local XLA service has been established, then an error is -// returned. -Status InitializePlatformName(const string& platform_name); - -// Returns the replica count that is currently set, regardless of whether the -// local XLA service has been instantiated yet or not. -int GetReplicaCount(); - // Registers a 'fn_capsule' as a CPU custom call target. // 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name // "xla._CPU_CUSTOM_CALL_TARGET". Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule); -// Wraps the local client's infeed-transfer function. -// -// The default device ordinal (0) is used. -Status TransferToInfeedLocal(const Literal& literal); +// Wrapper around an xla::LocalClient. +class LocalClient { + public: + // Initializes a local XLA client for `platform_name`. Returns an error if no + /// such platform exists, or if the platform has no visible devices. + static StatusOr Get(const string& platform_name); + + // Copyable and moveable; the class is just a wrapper around a + // xla::LocalClient pointer for convenient SWIG wrapping. + + // Returns the number of devices known to the XLA client. + int DeviceCount() const; -// Transfers the given literal to the infeed of the given replica. -// -// The replica number is resolved to an appropriate device ordinal. -Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); + // Wraps the local client's infeed-transfer function. + // + // The default device ordinal (0) is used. + Status TransferToInfeed(const Literal& literal, int device_ordinal); + + // Transfers a literal of the given shape from the outfeed of the given + // replica. + StatusOr TransferFromOutfeed(const Shape& shape, int device_ordinal); + + xla::LocalClient* client() const { return client_; } + + private: + LocalClient(xla::LocalClient* client); + + xla::LocalClient* client_; +}; -// Transfers a literal of the given shape from the outfeed of the given replica. -// -// The replica number is resolved to an appropriate device ordinal. -StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, - int replica_number); +class LocalShapedBufferTuple; // Represents a reference to literals that live in a device-allocated buffer via // XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a @@ -79,9 +76,9 @@ class LocalShapedBuffer { public: static StatusOr FromLiteral( const Literal& argument, const absl::optional& shape_with_layout, - int replica_number); + const LocalClient& client, int device_ordinal); - LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); + LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, xla::LocalClient* client); StatusOr ToLiteral() const; const Shape& shape() const; const ScopedShapedBuffer* shaped_buffer() const; @@ -90,8 +87,13 @@ class LocalShapedBuffer { // analogous to std::unique_ptr::release(). ShapedBuffer Release(); + // Destructures a tuple-valued LocalShapedBuffer into its constitutent + // elements in LocalShapedBufferTuple form. + StatusOr DestructureTuple(); + private: ScopedShapedBuffer shaped_buffer_; + xla::LocalClient* client_; }; // Result of a tuple destructuring operation on a LocalShapedBuffer -- this @@ -117,73 +119,21 @@ class LocalShapedBufferTuple { std::vector elements_; }; -// Destructures a tuple-valued LocalShapedBuffer into its constitutent elements -// in LocalShapedBufferTuple form. -StatusOr DestructureLocalShapedBufferTuple( - LocalShapedBuffer* local_shaped_buffer); - -// Represents a reference to literals that live in a device-allocated buffer via -// XRT. Specifically, wraps an int64 handle produced by running the allocation -// graph, and an XLA shape to track the referent's shape. -class XrtAllocation { - public: - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which allocation and deallocation - // graphs are run. - static StatusOr FromLiteral(const Literal& argument, - const string& session_target); - - XrtAllocation(int64 handle, Shape shape, const string& session_target); - ~XrtAllocation(); - StatusOr ToLiteral() const; - const Shape& shape() const; - const int64 handle() const; - - private: - const int64 handle_; - const Shape shape_; - const string session_target_; -}; - -// Result of a tuple destructuring operation on an XrtAllocation. -class XrtAllocationTuple { - public: - // Note: any XrtAllocation elements that are not Release()'d will be - // deallocated in the destructor. - explicit XrtAllocationTuple(std::vector elements); - - ~XrtAllocationTuple(); - - // Releases the ith element to the caller. Further attempts to release the ith - // element will return an invalid argument error. - StatusOr Release(int i); - - // Returns the number of elements in the destructured tuple. - int64 size() const; - - private: - std::vector elements_; -}; - -// Destructures a tuple-valued XrtAllocation into its constitutent elements -// in XrtAllocationTuple form. -// -// Accepts a `session_target` argument, used in constructing the -// `tensorflow::ClientSession` instance in which the sub-tupling graph is run, -// and passed along in constructing each constituent XrtAllocation. -StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation, const string& session_target); - // Represents a compiled computation that can be executed given handles to // device-allocated literals. Specifically, wraps an XLA LocalExecutable. -class CompiledLocalComputation { +class LocalExecutable { public: - CompiledLocalComputation(std::unique_ptr executable); + LocalExecutable(std::unique_ptr executable, + xla::DeviceAssignment device_assignment, + xla::LocalClient* client); int num_replicas() const { return executable_->build_options().num_replicas(); } + // Returns the device ordinals to which each replica is assigned. + std::vector DeviceOrdinals() const; + StatusOr Execute( absl::Span argument_handles); @@ -194,47 +144,22 @@ class CompiledLocalComputation { absl::Span > argument_handles); private: - std::unique_ptr executable_; -}; - -// Represents a compiled computation that can be executed given handles to -// device-allocated literals. Specifically, wraps an XRT computation handle. -class CompiledXrtComputation { - public: - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which the execution graph is run. - CompiledXrtComputation(const ProgramShape& program_shape, int64 handle, - const string& session_target); - ~CompiledXrtComputation(); - - StatusOr Execute( - absl::Span argument_handles); - - const ProgramShape& program_shape() const; - int64 handle() const; - - private: - const ProgramShape program_shape_; - const int64 handle_; - const string session_target_; + const std::unique_ptr executable_; + const xla::DeviceAssignment device_assignment_; + xla::LocalClient* const client_; }; -// Wraps a XlaComputation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a ComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. -class LocalComputation { +class Computation { public: - LocalComputation(XlaComputation computation); + Computation(XlaComputation computation); - StatusOr Compile( + StatusOr Compile( const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options); - - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which the compilation graph is run. - StatusOr CompileForXrt( - const std::vector& argument_shapes, const string& session_target); + const ExecutableBuildOptions* build_options, const LocalClient& client); const XlaComputation& computation() const; @@ -243,6 +168,12 @@ class LocalComputation { // string on failure. string GetSerializedProto() const; + // Returns the computation in human-readable HLO text format. + StatusOr GetHloText() const; + + // Returns the computation in graphviz dot format. + StatusOr GetHloDotGraph() const; + // Returns the program shape for this computation. StatusOr GetProgramShape() const; @@ -253,7 +184,7 @@ class LocalComputation { XlaComputation computation_; }; -// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// Wraps a XlaOp produced by a ComputationBuilder. This class is intended // to be made available to Python via SWIG. class LocalOp { public: @@ -270,20 +201,20 @@ class LocalOp { // Python. // - Set up the underlying builder to use the client library's // LocalClient. -// - Wrap Computations in LocalComputations for Python access. -// - Correspondingly unwrap incoming LocalComputations. -class LocalComputationBuilder { +// - Wrap Computations in Computations for Python access. +// - Correspondingly unwrap incoming Computations. +class ComputationBuilder { public: - LocalComputationBuilder(const string& computation_name); + ComputationBuilder(const string& computation_name); void SetOpMetadata(const OpMetadata& metadata); void ClearOpMetadata(); - // Returns an owned LocalComputation to the caller on success. - StatusOr Build(); + // Returns an owned Computation to the caller on success. + StatusOr Build(); - // Returns an owned LocalComputation to the caller on success with given root. - StatusOr BuildWithRoot(const LocalOp& root); + // Returns an owned Computation to the caller on success with given root. + StatusOr BuildWithRoot(const LocalOp& root); LocalOp Parameter(int64 parameter_number, const Shape& shape, const string& name); @@ -342,11 +273,11 @@ class LocalComputationBuilder { LocalOp ConcatInDim(absl::Span operands, int64 dimension); LocalOp SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const LocalComputation& select, + const LocalOp& operand, const Computation& select, absl::Span window_dimensions, absl::Span window_strides, absl::Span > padding, const LocalOp& source, - const LocalOp& init_value, const LocalComputation& scatter); + const LocalOp& init_value, const Computation& scatter); LocalOp Tuple(absl::Span elements); @@ -372,7 +303,7 @@ class LocalComputationBuilder { LocalOp BitcastConvertType(const LocalOp& operand, PrimitiveType new_element_type); - LocalOp Call(const LocalComputation& local_computation, + LocalOp Call(const Computation& local_computation, absl::Span operands); LocalOp CustomCall(const string& call_target_name, @@ -387,16 +318,16 @@ class LocalComputationBuilder { LocalOp Rev(const LocalOp& operand, absl::Span dimensions); LocalOp Map(absl::Span operands, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions_to_reduce); LocalOp ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, @@ -408,13 +339,13 @@ class LocalComputationBuilder { LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - LocalOp While(const LocalComputation& condition, const LocalComputation& body, + LocalOp While(const Computation& condition, const Computation& body, const LocalOp& init); LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, - const LocalComputation& true_computation, + const Computation& true_computation, const LocalOp& false_operand, - const LocalComputation& false_computation); + const Computation& false_computation); StatusOr IsConstant(const LocalOp& operand); @@ -438,11 +369,10 @@ class LocalComputationBuilder { absl::Span slice_sizes); LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices, - const LocalOp& updates, - const LocalComputation& update_computation, + const LocalOp& updates, const Computation& update_computation, const ScatterDimensionNumbers& dimension_numbers); - StatusOr BuildConstantSubGraph(const LocalOp& operand); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; @@ -531,10 +461,8 @@ class LocalComputationBuilder { // Functions for freeing resources from the Python side. void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer); -void DeleteXrtAllocation(XrtAllocation* allocation); -void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); -void DeleteCompiledXrtComputation(CompiledXrtComputation* computation); -void DeleteLocalComputation(LocalComputation* computation); +void DeleteLocalExecutable(LocalExecutable* computation); +void DeleteComputation(Computation* computation); } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 688fcf9f4d09ff9d5c73165d7b3ccc7a4d2d4f09..adce433b9628801b91d02643eecfcccfa6509692 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -23,6 +23,7 @@ limitations under the License. // C++ Python // -------------------------------------+--------------------------------------- // Span <- sequence of int +// vector -> sequence of int // Span <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray @@ -98,7 +99,7 @@ limitations under the License. // wrapped in a Python class (xla_client.Shape) so as not to expose // the raw pair externally. // -// Other SWIG object wrappers (e.g. of LocalComputation) are further +// Other SWIG object wrappers (e.g. of Computation) are further // wrapped by xla_client in order to set up a custom destructor that // triggers memory deallocation on the C++ side. @@ -108,6 +109,7 @@ limitations under the License. %nothread; %include "tensorflow/python/platform/base.i" +%include "tensorflow/compiler/xla/python/xla_data.i" %{ // Must be included first @@ -125,87 +127,6 @@ limitations under the License. using namespace xla; using namespace xla::swig; -namespace xla { - -namespace swig { - -bool GetIntAttr(PyObject* o, const char* field, int64* result) { - PyObject* fo = PyObject_GetAttrString(o, field); - if (!fo) { - return false; - } - const int64 value = numpy::PyIntOrPyLongToLong(fo); - if (value == -1 && PyErr_Occurred()) { - Py_DECREF(fo); - return false; - } - Py_DECREF(fo); - *result = value; - return true; -} - -// Returns "ok"; true if there is no error, false if there was an error. -bool HandleStringAttribute(PyObject* o, - const char* attr_name, - std::function f) { - if (!PyObject_HasAttrString(o, attr_name)) { - return true; // It's ok for the object to not have the attribute. - } - PyObject* attr = PyObject_GetAttrString(o, attr_name); - if (attr == nullptr) { - return false; // An error occurred getting the attribute. - } - if (attr == Py_None) { - Py_DECREF(attr); - return true; // The attribute is None, which we consider ok. - } - if (!PyString_Check(attr)) { - string message = absl::StrFormat("%s must be a string or none; got %s", - attr_name, numpy::PyObjectCppRepr(attr)); - PyErr_SetString(PyExc_TypeError, message.c_str()); - Py_DECREF(attr); - return false; // Type error, not ok. - } - f(PyString_AsString(attr)); - Py_DECREF(attr); - return true; // Handled string attribute, ok! -} - -bool HandleRepeatedInt64Attribute( - PyObject* o, const char* attr_name, - tensorflow::protobuf::RepeatedField* field) { - PyObject* seq = PyObject_GetAttrString(o, attr_name); - if (!seq) { - return false; - } - - int length = PySequence_Size(seq); - if (length == -1) { - Py_DECREF(seq); - return false; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(seq, i); - if (!item) { - Py_DECREF(seq); - return false; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(seq); - return false; - } - *field->Add() = dimension; - Py_DECREF(item); - } - Py_DECREF(seq); - return true; -} - -} // namespace swig -} // namespace xla %} // Required to use PyArray_* functions. @@ -213,57 +134,6 @@ bool HandleRepeatedInt64Attribute( tensorflow::ImportNumpy(); %} -// Basic types - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = PyBool_FromLong($1.ConsumeValueOrDie()); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) Status { - if (!$1.ok()) { - PyErr_SetString( - PyExc_RuntimeError, $1.ToString().c_str()); - SWIG_fail; - } - Py_INCREF(Py_None); - $result = Py_None; -} - -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.resize(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - Py_DECREF(o); - SWIG_fail; - } - temps[i] = numpy::PyIntOrPyLongToLong(py_int); - if (temps[i] == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); - SWIG_fail; - } - Py_DECREF(py_int); - Py_DECREF(o); - } - $1 = temps; -} - // Computation builder types %typemap(in) absl::Span( @@ -288,12 +158,12 @@ tensorflow::ImportNumpy(); // Computation and buffer/allocation types -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { - auto* value = $1.ValueOrDie(); + xla::swig::LocalClient value = $1.ValueOrDie(); { - auto* $1 = value; - $typemap(out, xla::swig::CompiledLocalComputation*) + auto $1 = value; + $typemap(out, xla::swig::LocalClient) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -301,12 +171,12 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::CompiledXrtComputation*) + $typemap(out, xla::swig::LocalExecutable*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -340,38 +210,12 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::XrtAllocation*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::XrtAllocationTuple*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::LocalComputation*) + $typemap(out, xla::swig::Computation*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -431,485 +275,6 @@ tensorflow::ImportNumpy(); $1 = temps; } -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - XrtAllocation* xrta; - if ((SWIG_ConvertPtr(o, (void**) &xrta, $descriptor(xla::swig::XrtAllocation*), - SWIG_POINTER_EXCEPTION)) == -1) { - SWIG_fail; - } - temps.push_back(xrta); - Py_DECREF(o); - } - $1 = temps; -} - -// Literal - -%typemap(in) const Literal& (StatusOr literal_status) { - literal_status = numpy::XlaLiteralFromPyObject($input); - if (!literal_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - SWIG_fail; - } - $1 = &literal_status.ValueOrDie(); -} - -%typemap(out) Literal (StatusOr obj_status) { - obj_status = numpy::PyObjectFromXlaLiteral(*$1); - if (!obj_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); - SWIG_fail; - } - $result = obj_status.ValueOrDie().release(); -} - -%typemap(out) StatusOr (StatusOr obj_status) { - if (!$1.ok()) { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } - obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); - if (!obj_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); - SWIG_fail; - } - $result = obj_status.ValueOrDie().release(); -} - -%typemap(in) const std::vector& (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); - if (!literal_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - Py_DECREF(o); - SWIG_fail; - } - temps.push_back(literal_status.ConsumeValueOrDie()); - Py_DECREF(o); - } - $1 = &temps; -} - -// OpMetadata - -%typemap(in) const OpMetadata& (OpMetadata temp) { - StatusOr statusor = numpy::OpMetadataFromPyObject($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; -} - -// Shape - -%typemap(out) const Shape& { - $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); -} - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release(); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = numpy::PyProgramShapeInfoFromXlaProgramShape( - $1.ConsumeValueOrDie()).release(); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - - -%typemap(in) const Shape& (Shape temp) { - StatusOr statusor = numpy::XlaShapeFromPyShape($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; -} - -%typemap(in) const absl::optional& ( - absl::optional temp) { - if ($input == Py_None) { - temp = absl::nullopt; - $1 = &temp; - } else { - StatusOr statusor = numpy::XlaShapeFromPyShape($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; - } -} - -%typemap(out) std::unique_ptr { - $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); -} - -%typemap(in) const std::vector& (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - StatusOr statusor = numpy::XlaShapeFromPyShape(o); - Py_DECREF(o); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temps.push_back(statusor.ConsumeValueOrDie()); - } - $1 = &temps; -} - -%typemap(in) const std::vector >& ( - std::vector > temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - if (o == Py_None) { - temps.push_back(absl::nullopt); - } else { - StatusOr statusor = numpy::XlaShapeFromPyShape(o); - Py_DECREF(o); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temps.push_back(statusor.ConsumeValueOrDie()); - } - } - $1 = &temps; -} - -// PrimitiveType - -%typemap(in) PrimitiveType { - PyObject* py_int = numpy::PyNumberToPyInt($input); - if (!py_int) { - PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); - SWIG_fail; - } - const long value = numpy::PyIntOrPyLongToLong(py_int); - if (value == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - SWIG_fail; - } - if (!PrimitiveType_IsValid(value)) { - PyErr_SetString( - PyExc_TypeError, "Argument not valid for PrimitiveType enum"); - Py_DECREF(py_int); - SWIG_fail; - } - $1 = static_cast(value); -} - -// Span> - -%typemap(in) absl::Span > - (std::vector > temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - if (!o) { - SWIG_fail; - } - PyObject* first = PyTuple_GetItem(o, 0); - if (!first) { - Py_DECREF(o); - SWIG_fail; - } - PyObject* first_pyint = numpy::PyNumberToPyInt(first); - if (!first_pyint) { - PyErr_SetString( - PyExc_TypeError, - "First pair item cannot be converted to int"); - Py_DECREF(o); - SWIG_fail; - } - PyObject* second = PyTuple_GetItem(o, 1); - if (!second) { - Py_DECREF(o); - Py_DECREF(first_pyint); - SWIG_fail; - } - PyObject* second_pyint = numpy::PyNumberToPyInt(second); - if (!second_pyint) { - PyErr_SetString( - PyExc_TypeError, - "Second pair item cannot be converted to int"); - Py_DECREF(o); - Py_DECREF(first_pyint); - SWIG_fail; - } - const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); - if (first_value == -1 && PyErr_Occurred()) { - Py_DECREF(o); - Py_DECREF(first_pyint); - Py_DECREF(second_pyint); - SWIG_fail; - } - const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); - if (second_value == -1 && PyErr_Occurred()) { - Py_DECREF(o); - Py_DECREF(first_pyint); - Py_DECREF(second_pyint); - SWIG_fail; - } - temps.push_back(std::make_pair(first_value, second_value)); - Py_DECREF(o); - } - $1 = temps; -} - -// DotDimensionNumbers - -%typemap(in) const DotDimensionNumbers& - (DotDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "lhs_contracting_dimensions", - dimension_numbers.mutable_lhs_contracting_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "rhs_contracting_dimensions", - dimension_numbers.mutable_rhs_contracting_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "lhs_batch_dimensions", - dimension_numbers.mutable_lhs_batch_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "rhs_batch_dimensions", - dimension_numbers.mutable_rhs_batch_dimensions())) { - SWIG_fail; - } - - $1 = &dimension_numbers; -} - -// PaddingConfig - -%typemap(in) const PaddingConfig& - (PaddingConfig padding_config) { - PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); - if (!dimensions) { - SWIG_fail; - } - - int length = PySequence_Size(dimensions); - if (length == -1) { - Py_DECREF(dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(dimensions, i); - if (!item) { - Py_DECREF(dimensions); - SWIG_fail; - } - int64 edge_padding_low, edge_padding_high, interior_padding; - if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) - || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) - || !GetIntAttr(item, "interior_padding", &interior_padding)) { - Py_DECREF(item); - Py_DECREF(dimensions); - SWIG_fail; - } - Py_DECREF(item); - - PaddingConfig::PaddingConfigDimension* dimension = - padding_config.add_dimensions(); - dimension->set_edge_padding_low(edge_padding_low); - dimension->set_edge_padding_high(edge_padding_high); - dimension->set_interior_padding(interior_padding); - } - Py_DECREF(dimensions); - - $1 = &padding_config; -} - -// ConvolutionDimensionNumbers - -%typemap(in) const ConvolutionDimensionNumbers& - (ConvolutionDimensionNumbers dimension_numbers) { - int64 value; - - if (!GetIntAttr($input, "input_batch_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_input_batch_dimension(value); - - if (!GetIntAttr($input, "input_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_input_feature_dimension(value); - - if (!GetIntAttr($input, "output_batch_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_output_batch_dimension(value); - - if (!GetIntAttr($input, "output_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_output_feature_dimension(value); - - if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_kernel_output_feature_dimension(value); - - if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_kernel_input_feature_dimension(value); - - if (!HandleRepeatedInt64Attribute( - $input, "input_spatial_dimensions", - dimension_numbers.mutable_input_spatial_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "kernel_spatial_dimensions", - dimension_numbers.mutable_kernel_spatial_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "output_spatial_dimensions", - dimension_numbers.mutable_output_spatial_dimensions())) { - SWIG_fail; - } - - $1 = &dimension_numbers; -} - -// GatherDimensionNumbers - -%typemap(in) const GatherDimensionNumbers& - (GatherDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "offset_dims", - dimension_numbers.mutable_offset_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "collapsed_slice_dims", - dimension_numbers.mutable_collapsed_slice_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "start_index_map", - dimension_numbers.mutable_start_index_map())) { - SWIG_fail; - } - - int64 value; - if (!GetIntAttr($input, "index_vector_dim", &value)) { - SWIG_fail; - } - dimension_numbers.set_index_vector_dim(value); - - $1 = &dimension_numbers; -} - -// ScatterDimensionNumbers - -%typemap(in) const ScatterDimensionNumbers& - (ScatterDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "update_window_dims", - dimension_numbers.mutable_update_window_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "inserted_window_dims", - dimension_numbers.mutable_inserted_window_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "scatter_dims_to_operand_dims", - dimension_numbers.mutable_scatter_dims_to_operand_dims())) { - SWIG_fail; - } - - int64 value; - if (!GetIntAttr($input, "index_vector_dim", &value)) { - SWIG_fail; - } - dimension_numbers.set_index_vector_dim(value); - - $1 = &dimension_numbers; -} - -// Span - -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - ReplicaGroup rgrp; - if (!HandleRepeatedInt64Attribute( - o, "replica_ids", - rgrp.mutable_replica_ids())) { - SWIG_fail; - } - temps.push_back(rgrp); - Py_DECREF(o); - } - $1 = temps; -} - - // ExecutableBuildOptions %typemap(in) const ExecutableBuildOptions* @@ -979,161 +344,150 @@ tensorflow::ImportNumpy(); %ignoreall %unignore xla; %unignore xla::swig; -%unignore xla::swig::InitializeReplicaCount; -%unignore xla::swig::InitializePlatformName; -%unignore xla::swig::GetReplicaCount; %unignore xla::swig::RegisterCpuCustomCallTarget; -%unignore xla::swig::TransferToInfeedLocal; -%unignore xla::swig::TransferToInfeedLocalReplica; -%unignore xla::swig::TransferFromOutfeedLocalReplica; +%unignore xla::swig::LocalClient; +%unignore xla::swig::LocalClient::Get; +%unignore xla::swig::LocalClient::DeviceCount; +%unignore xla::swig::LocalClient::TransferToInfeed; +%unignore xla::swig::LocalClient::TransferFromOutfeed; %unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; %unignore xla::swig::LocalShapedBuffer::shape; +%unignore xla::swig::LocalShapedBuffer::DestructureTuple; %unignore xla::swig::LocalShapedBufferTuple; %unignore xla::swig::LocalShapedBufferTuple::Release; %unignore xla::swig::LocalShapedBufferTuple::size; -%unignore xla::swig::XrtAllocation; -%unignore xla::swig::XrtAllocation::FromLiteral; -%unignore xla::swig::XrtAllocation::ToLiteral; -%unignore xla::swig::XrtAllocation::shape; -%unignore xla::swig::XrtAllocationTuple; -%unignore xla::swig::XrtAllocationTuple::Release; -%unignore xla::swig::XrtAllocationTuple::size; -%unignore xla::swig::CompiledLocalComputation; -%unignore xla::swig::CompiledLocalComputation::Execute; -%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; -%unignore xla::swig::CompiledXrtComputation; -%unignore xla::swig::CompiledXrtComputation::Execute; -%unignore xla::swig::LocalComputation; -%unignore xla::swig::LocalComputation::Compile; -%unignore xla::swig::LocalComputation::CompileForXrt; -%unignore xla::swig::LocalComputation::GetProgramShape; -%unignore xla::swig::LocalComputation::GetReturnValueShape; -%unignore xla::swig::LocalComputation::GetSerializedProto; +%unignore xla::swig::LocalExecutable; +%unignore xla::swig::LocalExecutable::DeviceOrdinals; +%unignore xla::swig::LocalExecutable::Execute; +%unignore xla::swig::LocalExecutable::ExecutePerReplica; +%unignore xla::swig::Computation; +%unignore xla::swig::Computation::Compile; +%unignore xla::swig::Computation::GetProgramShape; +%unignore xla::swig::Computation::GetReturnValueShape; +%unignore xla::swig::Computation::GetSerializedProto; +%unignore xla::swig::Computation::GetHloText; +%unignore xla::swig::Computation::GetHloDotGraph; %unignore xla::swig::LocalOp; -%unignore xla::swig::LocalComputationBuilder; -%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; -%unignore xla::swig::LocalComputationBuilder::Build; -%unignore xla::swig::LocalComputationBuilder::BuildWithRoot; -%unignore xla::swig::LocalComputationBuilder::SetOpMetadata; -%unignore xla::swig::LocalComputationBuilder::ClearOpMetadata; -%unignore xla::swig::LocalComputationBuilder::Parameter; -%unignore xla::swig::LocalComputationBuilder::GetShape; -%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape; -%unignore xla::swig::LocalComputationBuilder::Infeed; -%unignore xla::swig::LocalComputationBuilder::Outfeed; -%unignore xla::swig::LocalComputationBuilder::ConstantLiteral; -%unignore xla::swig::LocalComputationBuilder::ConstantR0; -%unignore xla::swig::LocalComputationBuilder::Iota; -%unignore xla::swig::LocalComputationBuilder::BroadcastedIota; -%unignore xla::swig::LocalComputationBuilder::Broadcast; -%unignore xla::swig::LocalComputationBuilder::BroadcastInDim; -%unignore xla::swig::LocalComputationBuilder::Pad; -%unignore xla::swig::LocalComputationBuilder::Reshape; -%unignore xla::swig::LocalComputationBuilder::Collapse; -%unignore xla::swig::LocalComputationBuilder::AllToAll; -%unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; -%unignore xla::swig::LocalComputationBuilder::Slice; -%unignore xla::swig::LocalComputationBuilder::SliceInDim; -%unignore xla::swig::LocalComputationBuilder::DynamicSlice; -%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; -%unignore xla::swig::LocalComputationBuilder::ConcatInDim; -%unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding; -%unignore xla::swig::LocalComputationBuilder::Select; -%unignore xla::swig::LocalComputationBuilder::Tuple; -%unignore xla::swig::LocalComputationBuilder::GetTupleElement; -%unignore xla::swig::LocalComputationBuilder::ConvertElementType; -%unignore xla::swig::LocalComputationBuilder::BitcastConvertType; -%unignore xla::swig::LocalComputationBuilder::Call; -%unignore xla::swig::LocalComputationBuilder::Transpose; -%unignore xla::swig::LocalComputationBuilder::Rev; -%unignore xla::swig::LocalComputationBuilder::Clamp; -%unignore xla::swig::LocalComputationBuilder::Map; -%unignore xla::swig::LocalComputationBuilder::Reduce; -%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding; -%unignore xla::swig::LocalComputationBuilder::RngNormal; -%unignore xla::swig::LocalComputationBuilder::RngUniform; -%unignore xla::swig::LocalComputationBuilder::RngBernoulli; -%unignore xla::swig::LocalComputationBuilder::While; -%unignore xla::swig::LocalComputationBuilder::Conditional; -%unignore xla::swig::LocalComputationBuilder::IsConstant; -%unignore xla::swig::LocalComputationBuilder::Eq; -%unignore xla::swig::LocalComputationBuilder::Ne; -%unignore xla::swig::LocalComputationBuilder::Ge; -%unignore xla::swig::LocalComputationBuilder::Gt; -%unignore xla::swig::LocalComputationBuilder::Lt; -%unignore xla::swig::LocalComputationBuilder::Le; -%unignore xla::swig::LocalComputationBuilder::Dot; -%unignore xla::swig::LocalComputationBuilder::DotGeneral; -%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated; -%unignore xla::swig::LocalComputationBuilder::Add; -%unignore xla::swig::LocalComputationBuilder::Sub; -%unignore xla::swig::LocalComputationBuilder::Mul; -%unignore xla::swig::LocalComputationBuilder::Div; -%unignore xla::swig::LocalComputationBuilder::Rem; -%unignore xla::swig::LocalComputationBuilder::Max; -%unignore xla::swig::LocalComputationBuilder::Min; -%unignore xla::swig::LocalComputationBuilder::And; -%unignore xla::swig::LocalComputationBuilder::Or; -%unignore xla::swig::LocalComputationBuilder::Xor; -%unignore xla::swig::LocalComputationBuilder::ShiftLeft; -%unignore xla::swig::LocalComputationBuilder::ShiftRightArithmetic; -%unignore xla::swig::LocalComputationBuilder::ShiftRightLogical; -%unignore xla::swig::LocalComputationBuilder::Not; -%unignore xla::swig::LocalComputationBuilder::Abs; -%unignore xla::swig::LocalComputationBuilder::Exp; -%unignore xla::swig::LocalComputationBuilder::Expm1; -%unignore xla::swig::LocalComputationBuilder::Floor; -%unignore xla::swig::LocalComputationBuilder::Ceil; -%unignore xla::swig::LocalComputationBuilder::Round; -%unignore xla::swig::LocalComputationBuilder::Log; -%unignore xla::swig::LocalComputationBuilder::Log1p; -%unignore xla::swig::LocalComputationBuilder::Sign; -%unignore xla::swig::LocalComputationBuilder::Cos; -%unignore xla::swig::LocalComputationBuilder::Sin; -%unignore xla::swig::LocalComputationBuilder::Tanh; -%unignore xla::swig::LocalComputationBuilder::Atan2; -%unignore xla::swig::LocalComputationBuilder::IsFinite; -%unignore xla::swig::LocalComputationBuilder::Pow; -%unignore xla::swig::LocalComputationBuilder::Neg; -%unignore xla::swig::LocalComputationBuilder::Sort; -%unignore xla::swig::LocalComputationBuilder::SortKeyVal; -%unignore xla::swig::LocalComputationBuilder::Sqrt; -%unignore xla::swig::LocalComputationBuilder::Rsqrt; -%unignore xla::swig::LocalComputationBuilder::Square; -%unignore xla::swig::LocalComputationBuilder::Reciprocal; -%unignore xla::swig::LocalComputationBuilder::Erfc; -%unignore xla::swig::LocalComputationBuilder::Erf; -%unignore xla::swig::LocalComputationBuilder::ErfInv; -%unignore xla::swig::LocalComputationBuilder::Lgamma; -%unignore xla::swig::LocalComputationBuilder::Digamma; -%unignore xla::swig::LocalComputationBuilder::Acos; -%unignore xla::swig::LocalComputationBuilder::Asin; -%unignore xla::swig::LocalComputationBuilder::Atan; -%unignore xla::swig::LocalComputationBuilder::Tan; -%unignore xla::swig::LocalComputationBuilder::Acosh; -%unignore xla::swig::LocalComputationBuilder::Asinh; -%unignore xla::swig::LocalComputationBuilder::Atanh; -%unignore xla::swig::LocalComputationBuilder::Cosh; -%unignore xla::swig::LocalComputationBuilder::Sinh; -%unignore xla::swig::LocalComputationBuilder::Real; -%unignore xla::swig::LocalComputationBuilder::Imag; -%unignore xla::swig::LocalComputationBuilder::Conj; -%unignore xla::swig::LocalComputationBuilder::Complex; -%unignore xla::swig::LocalComputationBuilder::Cholesky; -%unignore xla::swig::LocalComputationBuilder::QR; -%unignore xla::swig::LocalComputationBuilder::TriangularSolve; -%unignore xla::swig::LocalComputationBuilder::CustomCall; -%unignore xla::swig::LocalComputationBuilder::Gather; -%unignore xla::swig::LocalComputationBuilder::Scatter; -%unignore xla::swig::DeleteLocalComputation; -%unignore xla::swig::DestructureLocalShapedBufferTuple; -%unignore xla::swig::DestructureXrtAllocationTuple; +%unignore xla::swig::ComputationBuilder; +%unignore xla::swig::ComputationBuilder::ComputationBuilder; +%unignore xla::swig::ComputationBuilder::Build; +%unignore xla::swig::ComputationBuilder::BuildWithRoot; +%unignore xla::swig::ComputationBuilder::SetOpMetadata; +%unignore xla::swig::ComputationBuilder::ClearOpMetadata; +%unignore xla::swig::ComputationBuilder::Parameter; +%unignore xla::swig::ComputationBuilder::GetShape; +%unignore xla::swig::ComputationBuilder::GetReturnValueShape; +%unignore xla::swig::ComputationBuilder::Infeed; +%unignore xla::swig::ComputationBuilder::Outfeed; +%unignore xla::swig::ComputationBuilder::ConstantLiteral; +%unignore xla::swig::ComputationBuilder::ConstantR0; +%unignore xla::swig::ComputationBuilder::Iota; +%unignore xla::swig::ComputationBuilder::BroadcastedIota; +%unignore xla::swig::ComputationBuilder::Broadcast; +%unignore xla::swig::ComputationBuilder::BroadcastInDim; +%unignore xla::swig::ComputationBuilder::Pad; +%unignore xla::swig::ComputationBuilder::Reshape; +%unignore xla::swig::ComputationBuilder::Collapse; +%unignore xla::swig::ComputationBuilder::AllToAll; +%unignore xla::swig::ComputationBuilder::CrossReplicaSum; +%unignore xla::swig::ComputationBuilder::Slice; +%unignore xla::swig::ComputationBuilder::SliceInDim; +%unignore xla::swig::ComputationBuilder::DynamicSlice; +%unignore xla::swig::ComputationBuilder::DynamicUpdateSlice; +%unignore xla::swig::ComputationBuilder::ConcatInDim; +%unignore xla::swig::ComputationBuilder::SelectAndScatterWithGeneralPadding; +%unignore xla::swig::ComputationBuilder::Select; +%unignore xla::swig::ComputationBuilder::Tuple; +%unignore xla::swig::ComputationBuilder::GetTupleElement; +%unignore xla::swig::ComputationBuilder::ConvertElementType; +%unignore xla::swig::ComputationBuilder::BitcastConvertType; +%unignore xla::swig::ComputationBuilder::Call; +%unignore xla::swig::ComputationBuilder::Transpose; +%unignore xla::swig::ComputationBuilder::Rev; +%unignore xla::swig::ComputationBuilder::Clamp; +%unignore xla::swig::ComputationBuilder::Map; +%unignore xla::swig::ComputationBuilder::Reduce; +%unignore xla::swig::ComputationBuilder::ReduceWindowWithGeneralPadding; +%unignore xla::swig::ComputationBuilder::RngNormal; +%unignore xla::swig::ComputationBuilder::RngUniform; +%unignore xla::swig::ComputationBuilder::RngBernoulli; +%unignore xla::swig::ComputationBuilder::While; +%unignore xla::swig::ComputationBuilder::Conditional; +%unignore xla::swig::ComputationBuilder::IsConstant; +%unignore xla::swig::ComputationBuilder::Eq; +%unignore xla::swig::ComputationBuilder::Ne; +%unignore xla::swig::ComputationBuilder::Ge; +%unignore xla::swig::ComputationBuilder::Gt; +%unignore xla::swig::ComputationBuilder::Lt; +%unignore xla::swig::ComputationBuilder::Le; +%unignore xla::swig::ComputationBuilder::Dot; +%unignore xla::swig::ComputationBuilder::DotGeneral; +%unignore xla::swig::ComputationBuilder::ConvGeneralDilated; +%unignore xla::swig::ComputationBuilder::Add; +%unignore xla::swig::ComputationBuilder::Sub; +%unignore xla::swig::ComputationBuilder::Mul; +%unignore xla::swig::ComputationBuilder::Div; +%unignore xla::swig::ComputationBuilder::Rem; +%unignore xla::swig::ComputationBuilder::Max; +%unignore xla::swig::ComputationBuilder::Min; +%unignore xla::swig::ComputationBuilder::And; +%unignore xla::swig::ComputationBuilder::Or; +%unignore xla::swig::ComputationBuilder::Xor; +%unignore xla::swig::ComputationBuilder::ShiftLeft; +%unignore xla::swig::ComputationBuilder::ShiftRightArithmetic; +%unignore xla::swig::ComputationBuilder::ShiftRightLogical; +%unignore xla::swig::ComputationBuilder::Not; +%unignore xla::swig::ComputationBuilder::Abs; +%unignore xla::swig::ComputationBuilder::Exp; +%unignore xla::swig::ComputationBuilder::Expm1; +%unignore xla::swig::ComputationBuilder::Floor; +%unignore xla::swig::ComputationBuilder::Ceil; +%unignore xla::swig::ComputationBuilder::Round; +%unignore xla::swig::ComputationBuilder::Log; +%unignore xla::swig::ComputationBuilder::Log1p; +%unignore xla::swig::ComputationBuilder::Sign; +%unignore xla::swig::ComputationBuilder::Cos; +%unignore xla::swig::ComputationBuilder::Sin; +%unignore xla::swig::ComputationBuilder::Tanh; +%unignore xla::swig::ComputationBuilder::Atan2; +%unignore xla::swig::ComputationBuilder::IsFinite; +%unignore xla::swig::ComputationBuilder::Pow; +%unignore xla::swig::ComputationBuilder::Neg; +%unignore xla::swig::ComputationBuilder::Sort; +%unignore xla::swig::ComputationBuilder::SortKeyVal; +%unignore xla::swig::ComputationBuilder::Sqrt; +%unignore xla::swig::ComputationBuilder::Rsqrt; +%unignore xla::swig::ComputationBuilder::Square; +%unignore xla::swig::ComputationBuilder::Reciprocal; +%unignore xla::swig::ComputationBuilder::Erfc; +%unignore xla::swig::ComputationBuilder::Erf; +%unignore xla::swig::ComputationBuilder::ErfInv; +%unignore xla::swig::ComputationBuilder::Lgamma; +%unignore xla::swig::ComputationBuilder::Digamma; +%unignore xla::swig::ComputationBuilder::Acos; +%unignore xla::swig::ComputationBuilder::Asin; +%unignore xla::swig::ComputationBuilder::Atan; +%unignore xla::swig::ComputationBuilder::Tan; +%unignore xla::swig::ComputationBuilder::Acosh; +%unignore xla::swig::ComputationBuilder::Asinh; +%unignore xla::swig::ComputationBuilder::Atanh; +%unignore xla::swig::ComputationBuilder::Cosh; +%unignore xla::swig::ComputationBuilder::Sinh; +%unignore xla::swig::ComputationBuilder::Real; +%unignore xla::swig::ComputationBuilder::Imag; +%unignore xla::swig::ComputationBuilder::Conj; +%unignore xla::swig::ComputationBuilder::Complex; +%unignore xla::swig::ComputationBuilder::Cholesky; +%unignore xla::swig::ComputationBuilder::QR; +%unignore xla::swig::ComputationBuilder::TriangularSolve; +%unignore xla::swig::ComputationBuilder::CustomCall; +%unignore xla::swig::ComputationBuilder::Gather; +%unignore xla::swig::ComputationBuilder::Scatter; +%unignore xla::swig::DeleteComputation; %unignore xla::swig::DeleteLocalShapedBuffer; -%unignore xla::swig::DeleteXrtAllocation; -%unignore xla::swig::DeleteCompiledLocalComputation; -%unignore xla::swig::DeleteCompiledXrtComputation; +%unignore xla::swig::DeleteLocalExecutable; %thread; %include "tensorflow/compiler/xla/python/local_computation_builder.h" diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index aa692c786559f0d4e92e90e995aec51f394ca9c7..74f45b7cdcfd7d7b10a5832be37ac1fb34057743 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -567,6 +567,92 @@ PyObject* PyNumberToPyInt(PyObject* o) { } // namespace numpy +bool GetIntAttr(PyObject* o, const char* field, int64* result) { + PyObject* fo = PyObject_GetAttrString(o, field); + if (!fo) { + return false; + } + const int64 value = numpy::PyIntOrPyLongToLong(fo); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(fo); + return false; + } + Py_DECREF(fo); + *result = value; + return true; +} + +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleStringAttribute(PyObject* o, const char* attr_name, + std::function f) { + if (!PyObject_HasAttrString(o, attr_name)) { + return true; // It's ok for the object to not have the attribute. + } + PyObject* attr = PyObject_GetAttrString(o, attr_name); + if (attr == nullptr) { + return false; // An error occurred getting the attribute. + } + if (attr == Py_None) { + Py_DECREF(attr); + return true; // The attribute is None, which we consider ok. + } +#if PY_MAJOR_VERSION < 3 + if (!PyString_Check(attr)) { + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyString_AsString(attr)); +#else + if (!PyBytes_Check(attr)) { + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyBytes_AsString(attr)); +#endif + + Py_DECREF(attr); + return true; // Handled string attribute, ok! +} + +bool HandleRepeatedInt64Attribute( + PyObject* o, const char* attr_name, + tensorflow::protobuf::RepeatedField* field) { + PyObject* seq = PyObject_GetAttrString(o, attr_name); + if (!seq) { + return false; + } + + int length = PySequence_Size(seq); + if (length == -1) { + Py_DECREF(seq); + return false; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(seq, i); + if (!item) { + Py_DECREF(seq); + return false; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(seq); + return false; + } + *field->Add() = dimension; + Py_DECREF(item); + } + Py_DECREF(seq); + return true; +} + } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 89861fc4f0165a1305537457ad1ca61f8e06839b..eff8cda334f00050605febad66a61aa1c518c500 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -136,6 +136,18 @@ PyObject* PyNumberToPyInt(PyObject* o); } // namespace numpy +// Miscellaneous swig helpers that don't have a better home. + +bool GetIntAttr(PyObject* o, const char* field, int64* result); + +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleStringAttribute(PyObject* o, const char* attr_name, + std::function f); + +bool HandleRepeatedInt64Attribute( + PyObject* o, const char* attr_name, + tensorflow::protobuf::RepeatedField* field); + } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index d13bc73b25ea9bd06adb201457b271a07121ed22..d6a331fef10f018c6dc7df42a327d27bf4f06249 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -34,9 +34,16 @@ from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api from tensorflow.compiler.xla.service import hlo_pb2 +# Import the XRT backend, if available. +try: + # pylint: disable=g-import-not-at-top + from tensorflow.compiler.xla.python import pywrap_xrt as xrt_api +except ImportError: + xrt_api = None + # Most functions are snake_case for consistency with other modules, whereas -# method names of ComputationBuilder and LocalComputation are CamelCase for +# method names of ComputationBuilder and Computation are CamelCase for # consistency with XLA. # pylint: disable=invalid-name @@ -50,7 +57,7 @@ from tensorflow.compiler.xla.service import hlo_pb2 # which case we need to be able to detect when incompatible versions are # installed. def version(): - return (0, 1, 7) + return (0, 1, 8) _OP_METADATA_FIELDS = [ @@ -66,6 +73,10 @@ OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS) class Backend(object): """Abstract base class for XLA backends.""" + @abc.abstractmethod + def device_count(self): + """Returns the number of devices known to the backend.""" + @abc.abstractmethod def buffer_from_pyval(self, pyval, device=0): """Allocates a fresh buffer and populates it with `pyval`.""" @@ -79,7 +90,8 @@ class Backend(object): """Destructures a tuple buffer into a sequence of buffers.""" @abc.abstractmethod - def compile(self, computation, argument_shapes, compile_options): + def compile(self, computation, argument_shapes, result_shape, + compile_options): """Compiles a computation. Returns an executable.""" @abc.abstractmethod @@ -95,25 +107,41 @@ class Backend(object): """Runs an executable in a replicated manner.""" +def _maybe_encode_string(s): + if six.PY3: + return s.encode('utf-8') + else: + return s + + class XlaLocalBackend(Backend): """XLA backend implemented using the in-process xla::LocalClient API.""" + def __init__(self, platform=None): + platform = platform or _get_default_platform_name() + self.client = c_api.LocalClient.Get(_maybe_encode_string(platform)) + self._delete_buffer = c_api.DeleteLocalShapedBuffer + self._delete_executable = c_api.DeleteLocalExecutable + + def device_count(self): + return self.client.DeviceCount() + def buffer_from_pyval(self, pyval, device=0): - return c_api.LocalShapedBuffer.FromLiteral(pyval, None, device) + return c_api.LocalShapedBuffer.FromLiteral(pyval, None, self.client, device) def delete_buffer(self, c_buffer): - c_api.DeleteLocalShapedBuffer(c_buffer) + self._delete_buffer(c_buffer) def destructure_tuple(self, c_buffer): - result = c_api.DestructureLocalShapedBufferTuple(c_buffer) + result = c_buffer.DestructureTuple() return [result.Release(i) for i in xrange(result.size())] - def compile(self, c_computation, argument_shapes, compile_options): - return c_computation.Compile(argument_shapes, compile_options) + def compile(self, c_computation, argument_shapes, result_shape, + compile_options): + return c_computation.Compile(argument_shapes, compile_options, self.client) def delete_executable(self, executable): - assert isinstance(executable, c_api.CompiledLocalComputation) - c_api.DeleteCompiledLocalComputation(executable) + self._delete_executable(executable) def execute(self, executable, args): return executable.Execute(args) @@ -129,29 +157,35 @@ class XrtBackend(Backend): def __init__(self, target): self.target = target + self._delete_buffer = xrt_api.DeleteXrtAllocation + self._delete_executable = xrt_api.DeleteXrtExecutable + + def device_count(self): + return 1 # Multidevice execution not implemented. def buffer_from_pyval(self, pyval, device=0): if device != 0: raise NotImplementedError( 'Multi-replica execution is not yet supported via the XRT backend.') - return c_api.XrtAllocation.FromLiteral(pyval, - _maybe_encode_string(self.target)) + return xrt_api.XrtAllocation.FromLiteral(pyval, + _maybe_encode_string(self.target)) def delete_buffer(self, c_buffer): - c_api.DeleteXrtAllocation(c_buffer) + self._delete_buffer(c_buffer) def destructure_tuple(self, c_buffer): - result = c_api.DestructureXrtAllocationTuple( + result = xrt_api.DestructureXrtAllocationTuple( c_buffer, _maybe_encode_string(self.target)) return [result.Release(i) for i in xrange(result.size())] - def compile(self, c_computation, argument_shapes, compile_options): - return c_computation.CompileForXrt(argument_shapes, - _maybe_encode_string(self.target)) + def compile(self, c_computation, argument_shapes, result_shape, + compile_options): + return xrt_api.XrtExecutable.CompileForXrt( + c_computation.GetSerializedProto(), argument_shapes, result_shape, + _maybe_encode_string(self.target)) def delete_executable(self, executable): - assert isinstance(executable, c_api.CompiledXrtComputation) - c_api.DeleteCompiledXrtComputation(executable) + self._delete_executable(executable) def execute(self, executable, args): return executable.Execute(args) @@ -163,7 +197,20 @@ class XrtBackend(Backend): return [executable.Execute(per_replica_args[0])] -XLA_LOCAL_BACKEND = XlaLocalBackend() +_default_platform_name = 'Host' +_default_backend = None + + +def _get_default_platform_name(): + return _default_platform_name + + +def _get_default_local_backend(): + global _default_backend + global _default_platform_name + if _default_backend is None: + _default_backend = XlaLocalBackend(_default_platform_name) + return _default_backend class BackendType(enum.Enum): @@ -174,7 +221,7 @@ class BackendType(enum.Enum): def BackendSpec(backend, target): """Compatibility wrapper to support older clients. Do not use in new code.""" if backend == BackendType.XLA_LOCAL: - return XLA_LOCAL_BACKEND + return _get_default_local_backend() elif backend == BackendType.XRT: return XrtBackend(target) else: @@ -201,13 +248,6 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): source_line=lineno) -def _maybe_encode_string(s): - if six.PY3: - return s.encode('utf-8') - else: - return s - - class PaddingType(enum.Enum): VALID = 1 SAME = 2 @@ -346,22 +386,18 @@ class LocalBuffer(object): means the referent is in device memory. """ - def __init__(self, c_buffer, backend, replica): + def __init__(self, c_buffer, backend, device): self.c_buffer = c_buffer self._backend = backend - self._replica = replica + self._device = device @staticmethod - def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND): + def from_pyval(pyval, device=0, backend=None): """Allocate and copy to XLA the given python value.""" + backend = backend or _get_default_local_backend() pyval = require_numpy_array_layout(pyval) - num_replicas = get_replica_count() - if not 0 <= replica < num_replicas: - raise ValueError( - 'Attempt to place buffer on replica {} when the replica count is {}' - .format(replica, num_replicas)) - cbuf = backend.buffer_from_pyval(pyval, replica) - return LocalBuffer(cbuf, backend, replica) + cbuf = backend.buffer_from_pyval(pyval, device) + return LocalBuffer(cbuf, backend, device) def to_py(self): return self.c_buffer.ToLiteral() @@ -369,8 +405,8 @@ class LocalBuffer(object): def shape(self): return _wrap_shape(self.c_buffer.shape()) - def replica(self): - return self._replica + def device(self): + return self._device def delete(self): if self.c_buffer is not None: @@ -383,7 +419,7 @@ class LocalBuffer(object): result = self._backend.destructure_tuple(self.c_buffer) self.delete() return tuple( - LocalBuffer(sub_buffer, replica=self._replica, backend=self._backend) + LocalBuffer(sub_buffer, device=self._device, backend=self._backend) for sub_buffer in result) def is_deleted(self): @@ -595,7 +631,7 @@ class CompileOptions(object): self.num_replicas = get_replica_count() -def transfer_to_infeed(value, replica_number=None): +def transfer_to_infeed(value, device_ordinal=0): """Transfers the given value into the XLA infeed queue. XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with @@ -605,52 +641,50 @@ def transfer_to_infeed(value, replica_number=None): Args: value: the value that the caller would like to enqueue into the XLA infeed queue - replica_number: the replica number to infeed the value to -- if not - provided, then the default replica (trivially replica 0) is used. + device_ordinal: the device to infeed the value to. Each device has a + distinct infeed queue. """ - if replica_number is None: - c_api.TransferToInfeedLocal(require_numpy_array_layout(value)) - else: - c_api.TransferToInfeedLocalReplica( - require_numpy_array_layout(value), replica_number) + # TODO(phawkins): support non-default backends. + backend = _get_default_local_backend() + backend.client.TransferToInfeed( + require_numpy_array_layout(value), device_ordinal) -def transfer_from_outfeed(shape, replica_number=None): - """Transfers a literal of the given shape from replica_number's outfeed. +def transfer_from_outfeed(shape, device_ordinal=0): + """Transfers a literal of the given shape from `device_ordinal`'s outfeed. Args: shape: The shape of the value to transfer from outfeed. - replica_number: The replica number ordinal to transfer the outfeed value - from. (Each replica has a distinct outfeed queue.) + device_ordinal: The device ordinal to transfer the outfeed value from. Each + device has a distinct outfeed queue.. Returns: The literal value that is produced from the outfeed queue. """ - return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0) + # TODO(phawkins): support non-default backends. + backend = _get_default_local_backend() + return backend.client.TransferFromOutfeed(shape, device_ordinal) -class LocalComputation(object): - """Python wrapper for a local XLA Computation. +class Computation(object): + """Python wrapper for an XLA Computation. - A LocalComputation can be executed if it is compiled. Otherwise, it - can still be used as a Computation where required by the - ComputationBuilder methods. + A Computation can be compiled to form an Executable, or used as a + subcomputation in ComputationBuilder methods. """ - def __init__(self, c_computation, is_compiled, backend=XLA_LOCAL_BACKEND): + def __init__(self, c_computation, backend=None): self._c_computation = c_computation + # The backend argument is deprecated. Pass a backend to Compile() instead. self._backend = backend - self._is_compiled = is_compiled + self._delete_computation = c_api.DeleteComputation @property def computation(self): - if self._is_compiled: - raise ValueError( - 'Attempt to read the XLA computation of a compiled LocalComputation.') return self._c_computation def GetProto(self): - """Get the HloModuleProto proto object in this local computation. + """Get the HloModuleProto proto object in this computation. Returns: An HloModuleProto proto object that has the whole-graph information. @@ -659,30 +693,41 @@ class LocalComputation(object): proto = hlo_pb2.HloModuleProto.FromString(serialized) return proto - def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): - """Compiles an un-compiled local computation. + def GetHloText(self): + """Get the textual HLO representation of this computation. + + Returns: + A string containing the textual HLO. + """ + return self.computation.GetHloText() + + def GetHloDotGraph(self): + """Get a Graphviz Dot representation of this computation. + + Returns: + A string containing the graphviz dot graph. + """ + return self.computation.GetHloDotGraph() - Local computations are the result of a "LocalComputationBuild'ing" process - -- they start in uncompiled form, and via a call to Compile() turn into a - compiled local computation. + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None, + backend=None): + """Compiles a computation. - Raises: - ValueError: if this is already a compiled local computation. + Computations are the result of a "ComputationBuild'ing" process. Arguments: argument_shapes: parameter shapes -- they are first laid out by layout_fn if layout_fn is provided. Otherwise, the default layout for those shapes will be used. - compile_options: options to use for compilation, includes an optional - laid out result shape for the computation. + compile_options: options to use for compilation, includes an optional laid + out result shape for the computation. layout_fn: lambda that is used to lay out the argument/result shapes. + backend: a `Backend` for which an executable should be generated. Returns: - A newly *compiled* local computation instance. + A Executable instance. """ - if self._is_compiled: - raise ValueError('Attempt to compile a compiled local XLA computation.') - + backend = backend or self._backend or _get_default_local_backend() result_shape = _wrap_shape(self.computation.GetReturnValueShape()) if layout_fn: @@ -695,18 +740,20 @@ class LocalComputation(object): compile_options = compile_options or CompileOptions() compile_options.result_shape = result_shape - c = self._backend.compile(self.computation, argument_shapes, - compile_options) - return LocalComputation(c, is_compiled=True, backend=self._backend) + c = backend.compile(self.computation, argument_shapes, result_shape, + compile_options) + return Executable(c, backend=backend) def CompileWithExampleArguments(self, arguments=(), compile_options=None, - layout_fn=None): + layout_fn=None, + backend=None): return self.Compile( argument_shapes=[Shape.from_pyval(arg) for arg in arguments], compile_options=compile_options, - layout_fn=layout_fn) + layout_fn=layout_fn, + backend=backend) def GetProgramShape(self): (arg_shapes, result_shape) = self._c_computation.GetProgramShape() @@ -716,13 +763,31 @@ class LocalComputation(object): def GetReturnValueShape(self): return _wrap_shape(self._c_computation.GetReturnValueShape()) + def __del__(self): + if self._c_computation: + self._delete_computation(self._c_computation) + + +class Executable(object): + """Python wrapper for an XLA Executable.""" + + def __init__(self, c_executable, backend=None): + self._c_executable = c_executable + self._device_ordinals = c_executable.DeviceOrdinals() + self._backend = backend + + def DeviceOrdinals(self): + """Returns a list containing the device ordinals for each replica.""" + return self._device_ordinals + def Execute(self, arguments=(), check_for_deleted_args=True): """Execute on one replica with LocalBuffer arguments and return value.""" if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): raise ValueError('Executing with deleted local buffer argument') raw_args = [arg.c_buffer for arg in arguments] - output_buffer = self._backend.execute(self._c_computation, raw_args) - return LocalBuffer(output_buffer, backend=self._backend, replica=0) + output_buffer = self._backend.execute(self._c_executable, raw_args) + return LocalBuffer( + output_buffer, backend=self._backend, device=self._device_ordinals[0]) def ExecutePerReplica(self, arguments=None): """Execute on many replicas with LocalBuffer arguments and return value. @@ -732,14 +797,12 @@ class LocalComputation(object): sequence comprises the arguments for execution on the i'th replica. Returns: - A list of the computation's outputs on each replica, as a LocalBuffer. If + A list of the computation's outputs for each replica, as a LocalBuffer. If a shallow sequence of arguments was passed in for `arguments`, then the sole, zero'th replica's output is returned instead, as a LocalBuffer. """ - if not self._is_compiled: - raise ValueError('Cannot execute an uncompiled local XLA computation.') if arguments is None: - arguments = ((),) * get_replica_count() + arguments = ((),) * len(self._device_ordinals) else: arguments = [list(replica_args) for replica_args in arguments] @@ -748,30 +811,35 @@ class LocalComputation(object): for arg in replica_args: if arg.is_deleted(): raise ValueError('Executing with deleted local buffer argument') - if arg.replica() != replica: + if arg.device() != self._device_ordinals[replica]: raise ValueError( - 'Executing on replica {} with argument from replica {}'.format( - replica, arg.replica())) + 'Executing on device {} with argument from device {}'.format( + self._device_ordinals[replica], arg.device())) # Pull out argument buffer handles + # pylint: disable=g-complex-comprehension stripped_args = [ [arg.c_buffer for arg in replica_args] for replica_args in arguments ] # Execute - output_buffers = self._backend.execute_replicated( - self._c_computation, stripped_args) + output_buffers = self._backend.execute_replicated(self._c_executable, + stripped_args) # Wrap output handles in LocalBuffer instances return tuple( - LocalBuffer(output_buffer, backend=self._backend, replica=replica) + LocalBuffer( + output_buffer, + backend=self._backend, + device=self._device_ordinals[replica]) for replica, output_buffer in enumerate(output_buffers)) def ExecuteWithPythonValues(self, arguments=()): """Execute on one replica with Python values as arguments and output.""" def put(arg): - return LocalBuffer.from_pyval(arg, backend=self._backend) + return LocalBuffer.from_pyval( + arg, device=self._device_ordinals[0], backend=self._backend) arguments = [put(arg) for arg in arguments] return self.Execute(arguments).to_py() @@ -779,22 +847,19 @@ class LocalComputation(object): def ExecuteWithPythonValuesPerReplica(self, arguments): """Execute on many replicas with Python values as arguments and output.""" - def put(arg, replica): - return LocalBuffer.from_pyval(arg, replica, backend=self._backend) + def put(arg, device): + return LocalBuffer.from_pyval(arg, device, backend=self._backend) - arguments = [[put(arg, replica) - for arg in replica_args] - for replica, replica_args in enumerate(arguments)] + # pylint: disable=g-complex-comprehension + arguments = [[ + put(arg, self._device_ordinals[replica]) for arg in replica_args + ] for replica, replica_args in enumerate(arguments)] return [out.to_py() for out in self.ExecutePerReplica(arguments)] def __del__(self): # Python may have freed c_api first. - if c_api and self._c_computation: - if self._is_compiled: - self._backend.delete_executable(self._c_computation) - else: - assert isinstance(self._c_computation, c_api.LocalComputation) - c_api.DeleteLocalComputation(self._c_computation) + if c_api and self._c_executable: + self._backend.delete_executable(self._c_executable) def _make_replica_group_proto(replica_group): @@ -807,8 +872,8 @@ class ComputationBuilder(object): """XLA computation builder. Enqueues XLA ops in sequence and in order to build a - LocalComputation, which in turn can be compiled into a - CompiledLocalComputation, which in turn can be locally executed. + Computation, which in turn can be compiled into a + LocalExecutable, which in turn can be locally executed. """ # The methods of this class map 1-to-1 onto the XLA C++ @@ -819,16 +884,23 @@ class ComputationBuilder(object): # pylint: disable=g-doc-args def __init__(self, name): - self._client = c_api.LocalComputationBuilder(name.encode('utf8')) + self._client = c_api.ComputationBuilder(name.encode('utf8')) self._parameter_numbering = itertools.count() - def Build(self, root=None, backend=XLA_LOCAL_BACKEND): + def Build(self, root=None, backend=None): + """Builds a `Computation` from the contents of the builder. + + Args: + root: if not None, the operator containing the return value of the + computation. + backend: deprecated. Pass a `backend` to `Computation.Compile` instead. + Returns: + A `Computation`. + """ if root is not None: - return LocalComputation( - self._client.BuildWithRoot(root), is_compiled=False, backend=backend) + return Computation(self._client.BuildWithRoot(root), backend=backend) else: - return LocalComputation( - self._client.Build(), is_compiled=False, backend=backend) + return Computation(self._client.Build(), backend=backend) def SetOpMetadata(self, op_metadata): """Set metadata for operations that are about to be enqueued.""" @@ -1480,7 +1552,7 @@ class ComputationBuilder(object): Args: operand: a LocalOp to test. - Returns: a LocalComputation that is rooted on the given `operand` which is a + Returns: a Computation that is rooted on the given `operand` which is a compile-time constant. """ return self._client.BuildConstantSubGraph(operand) @@ -1681,7 +1753,7 @@ def _forward_methods_to_local_builder(): Set up methods, corresponding to unary and binary XLA operations, whose calls are forwarded in a boilerplate manner to the underlying - LocalComputationBuilder C-extension API. + ComputationBuilder C-extension API. """ def forward_to_local_builder_with_handles(target_method, is_binop=False): @@ -1701,13 +1773,13 @@ def _forward_methods_to_local_builder(): for method_name in _UNARY_OPS: forward = forward_to_local_builder_with_handles( - getattr(c_api.LocalComputationBuilder, method_name)) + getattr(c_api.ComputationBuilder, method_name)) forward.__name__ = method_name setattr(ComputationBuilder, method_name, forward) for method_name in _BINARY_OPS: forward = forward_to_local_builder_with_handles( - getattr(c_api.LocalComputationBuilder, method_name), is_binop=True) + getattr(c_api.ComputationBuilder, method_name), is_binop=True) forward.__name__ = method_name setattr(ComputationBuilder, method_name, forward) @@ -1715,8 +1787,14 @@ def _forward_methods_to_local_builder(): _forward_methods_to_local_builder() +_default_replica_count = 1 + + def initialize_replica_count(replica_count): - """Initializes the desired replica count to use on XLA service init. + """Initializes the default replica count to use. + + Deprecated; pass `num_replicas` as an option to `Computation.Compile()` + instead. Args: replica_count: number of replicas that are desired for set up during XLA @@ -1725,31 +1803,30 @@ def initialize_replica_count(replica_count): Raises: A runtime exception if the XLA service has already been initialized. """ - c_api.InitializeReplicaCount(replica_count) - + global _default_replica_count + _default_replica_count = replica_count -def initialize_platform_name(platform_name): - """Initializes the desired platform name to use on XLA service init. - Args: - platform_name: string name of platform. +def get_replica_count(): + """Returns the default replica count. - Raises: - A runtime exception if the XLA service has already been initialized. - A runtime exception if the platform does not exist, or there are no devices - with that platform. + Deprecated; pass `num_replicas` as an option to `Computation.Compile()` + instead. """ - platform_name = _maybe_encode_string(platform_name) - c_api.InitializePlatformName(platform_name) + return _default_replica_count -def get_replica_count(): - """Returns the current replica count used for the XLA service. +def initialize_platform_name(platform_name): + """Initializes the default platform name to use for XLA. - Note: this will return a value whether the XLA service has been initialized - yet or not. + Args: + platform_name: string name of platform. """ - return c_api.GetReplicaCount() + global _default_platform_name + _default_platform_name = platform_name + + # Make sure the platform is valid by trying to instantiate it. + _get_default_local_backend() def register_cpu_custom_call_target(name, fn): diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index aa38c06cf908079e627156f51264965892de7ff0..45ed209c992339a766afaca478f140f57640324e 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -29,7 +29,7 @@ from tensorflow.compiler.xla.python import xla_client import unittest -class LocalComputationTest(unittest.TestCase): +class ComputationTest(unittest.TestCase): """Base class for running an XLA Computation through the local client.""" def _NewComputation(self, name=None): @@ -85,7 +85,27 @@ def NumpyArrayBool(*args, **kwargs): return np.array(*args, dtype=np.bool, **kwargs) -class ComputationsWithConstantsTest(LocalComputationTest): +class ComputationPrinting(unittest.TestCase): + + def ExampleComputation(self): + builder = xla_client.ComputationBuilder("acomputation") + p0 = builder.ParameterFromNumpy(np.float32(0)) + p1 = builder.ParameterFromNumpy(np.zeros((4,), np.float32)) + builder.Mul(p0, p1) + return builder.Build() + + def testComputationToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.GetHloText() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testComputationToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = computation.GetHloDotGraph() + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + +class ComputationsWithConstantsTest(ComputationTest): """Tests focusing on Constant ops.""" def testConstantScalarSumS8(self): @@ -304,7 +324,7 @@ class ComputationsWithConstantsTest(LocalComputationTest): self._ExecuteAndCompareClose(c, expected=0.75) -class ParametersTest(LocalComputationTest): +class ParametersTest(ComputationTest): """Tests focusing on Parameter ops and argument-passing.""" def setUp(self): @@ -384,7 +404,7 @@ class ParametersTest(LocalComputationTest): expected=[-4.3, 1.3, -6.3, 3.3]) -class LocalBufferTest(LocalComputationTest): +class LocalBufferTest(ComputationTest): """Tests focusing on execution with LocalBuffers.""" def _Execute(self, c, arguments): @@ -482,7 +502,7 @@ class LocalBufferTest(LocalComputationTest): self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) -class SingleOpTest(LocalComputationTest): +class SingleOpTest(ComputationTest): """Tests for single ops. The goal here is smoke testing - to exercise the most basic functionality of @@ -1175,7 +1195,7 @@ class SingleOpTest(LocalComputationTest): np.testing.assert_allclose(g, expected, rtol=1e-4) -class EmbeddedComputationsTest(LocalComputationTest): +class EmbeddedComputationsTest(ComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" def _CreateConstantS32Computation(self): @@ -1639,7 +1659,7 @@ class EmbeddedComputationsTest(LocalComputationTest): self._ExecuteAndCompareClose(c, expected=expected) -class ErrorTest(LocalComputationTest): +class ErrorTest(ComputationTest): def setUp(self): self.f32_scalar_2 = NumpyArrayF32(2.0) @@ -1656,7 +1676,7 @@ class ErrorTest(LocalComputationTest): lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) -class ComputationRootTest(LocalComputationTest): +class ComputationRootTest(ComputationTest): """Tests related to setting the root of the computation.""" def testComputationRootDifferentFromLastOp(self): diff --git a/tensorflow/compiler/xla/python/xla_data.i b/tensorflow/compiler/xla/python/xla_data.i new file mode 100644 index 0000000000000000000000000000000000000000..974f314af24f61c0015a8d51c16dff1bfc84c7cc --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_data.i @@ -0,0 +1,654 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// SWIG typemaps and declarations for building, compiling, and +// executing XLA computations, wrapping most of what is declared in +// xla_data.h. +// +// The typemaps below implement/assert the following correspondences +// (with elaborations below): +// +// C++ Python +// -------------------------------------+--------------------------------------- +// Span <- sequence of int +// vector -> sequence of int +// Span <- sequence of LocalOp +// Literal <-> (nested tuple of) numpy ndarray +// std::vector <- sequence of (nested tuple of) ndarray +// Shape -> pair holding (dtype, dimensions) +// <- object duck-typed as xla_client.Shape +// ProgramShape -> pair of ([arg_shapes], ret_shape) +// std::vector <- sequence of xla_client.Shape objects +// PrimitiveType <- int +// Span> <- sequence of int pairs +// PaddingConfig proto <- corresponding Python proto +// ConvolutionDimensionNumbers proto <- corresponding Python proto +// DotDimensionNumbers proto <- corresponding Python proto +// GatherDimensionNumbers proto <- corresponding Python proto +// ScatterDimensionNumbers proto <- corresponding Python proto +// Span <- sequence of ReplicaGroup Python proto +// +// Arrows indicate whether a conversion only ever occurs in one +// direction, or whether it is maintained bidirectionally. +// +// The Python objects corresponding to C++ Literals have the type: +// +// T = ndarray | (T, ...) +// +// where a terminal numpy ndarray translates to a Literal with a +// non-tuple Shape, an XLA primitive element type corresponding to the +// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates +// to a tuple-shaped Literal whose tuple components are translated +// recursively. For example, if x is a numpy ndarray in Python, with +// shape (2, 3) and dtype of dtype('float32'), then x translates to a +// Literal with rank 2, dimension 2 and 3, and XLA primitive type +// F32. Meanwhile, +// +// (x, (x, x), (x,)), +// +// translates to a tuple-shaped XLA Literal, whose component subshapes +// are a 2x3 F32-shaped literal followed by two tuple-shaped literals. +// +// Shapes output by C++ become Python objects with the type: +// +// T = (dtype, S) +// S = DIMENSIONS | TUPLE_SHAPES +// DIMENSIONS = (int, ...) +// TUPLE_SHAPES = (T, ...) +// +// In the pair described by the T rule, the terminal dtype determines +// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is +// dtype('O'), numpy's object dtype, the structure represents a tuple +// shape and the expansion of the non-terminal S is +// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type +// and S expands into DIMENSIONS giving dimension sizes. For example: +// +// (dtype('float32'), (3, 5, 7)) +// +// describes a 3x5x7 array of F32s, and +// +// (dtype('O'), ((dtype('float32'), (2, 3)), +// (dtype('float64'), (4, 5)))) +// +// describes a tuple shape with two subshapes: the first a 2x3 F32, +// and the other a 4x5 F64. +// +// The Python int corresponding to a PrimitiveType enum must be valid +// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32). +// +// The SWIG object wrappers generated by this file are not intended +// for end use, but rather for internal use in the Python XLA client, +// xla_client.py. +// +// One central reason for the Python-side indirection is that the +// Python-side objects produced by the typemaps in this file are +// further packaged up by xla_client before being passed on. For +// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. +// +// Other SWIG object wrappers (e.g. of Computation) are further +// wrapped by xla_client in order to set up a custom destructor that +// triggers memory deallocation on the C++ side. + +%module(threads="1") xla_data + +// Keep the GIL except where explicitly specified. +%nothread; + +%include "tensorflow/python/platform/base.i" + +%{ +// Must be included first +#include "tensorflow/python/lib/core/numpy.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/python/numpy_bridge.h" + +using namespace xla; +using namespace xla::swig; + +%} + +// Basic types + + +%typemap(out) std::vector { + PyObject* out = PyList_New($1.size()); + for (int i = 0; i < $1.size(); ++i) { + PyList_SET_ITEM(out, i, PyInt_FromLong($1[i])); + } + $result = out; +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyBool_FromLong($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyString_FromString($1.ConsumeValueOrDie().c_str()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) Status { + if (!$1.ok()) { + PyErr_SetString( + PyExc_RuntimeError, $1.ToString().c_str()); + SWIG_fail; + } + Py_INCREF(Py_None); + $result = Py_None; +} + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + Py_DECREF(o); + SWIG_fail; + } + temps[i] = numpy::PyIntOrPyLongToLong(py_int); + if (temps[i] == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + SWIG_fail; + } + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// Literal + +%typemap(in) const Literal& (StatusOr literal_status) { + literal_status = numpy::XlaLiteralFromPyObject($input); + if (!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + SWIG_fail; + } + $1 = &literal_status.ValueOrDie(); +} + +%typemap(out) Literal (StatusOr obj_status) { + obj_status = numpy::PyObjectFromXlaLiteral(*$1); + if (!obj_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); + SWIG_fail; + } + $result = obj_status.ValueOrDie().release(); +} + +%typemap(out) StatusOr (StatusOr obj_status) { + if (!$1.ok()) { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } + obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); + if (!obj_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); + SWIG_fail; + } + $result = obj_status.ValueOrDie().release(); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); + if (!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + Py_DECREF(o); + SWIG_fail; + } + temps.push_back(literal_status.ConsumeValueOrDie()); + Py_DECREF(o); + } + $1 = &temps; +} + +// OpMetadata + +%typemap(in) const OpMetadata& (OpMetadata temp) { + StatusOr statusor = numpy::OpMetadataFromPyObject($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; +} + +// Shape + +%typemap(out) const Shape& { + $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release(); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyProgramShapeInfoFromXlaProgramShape( + $1.ConsumeValueOrDie()).release(); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(in) const Shape& (Shape temp) { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; +} + +%typemap(in) const absl::optional& ( + absl::optional temp) { + if ($input == Py_None) { + temp = absl::nullopt; + $1 = &temp; + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; + } +} + +%typemap(out) std::unique_ptr { + $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } + $1 = &temps; +} + +%typemap(in) const std::vector >& ( + std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (o == Py_None) { + temps.push_back(absl::nullopt); + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } + } + $1 = &temps; +} + +// PrimitiveType + +%typemap(in) PrimitiveType { + PyObject* py_int = numpy::PyNumberToPyInt($input); + if (!py_int) { + PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); + SWIG_fail; + } + const long value = numpy::PyIntOrPyLongToLong(py_int); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + SWIG_fail; + } + if (!PrimitiveType_IsValid(value)) { + PyErr_SetString( + PyExc_TypeError, "Argument not valid for PrimitiveType enum"); + Py_DECREF(py_int); + SWIG_fail; + } + $1 = static_cast(value); +} + +// Span> + +%typemap(in) absl::Span > + (std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (!o) { + SWIG_fail; + } + PyObject* first = PyTuple_GetItem(o, 0); + if (!first) { + Py_DECREF(o); + SWIG_fail; + } + PyObject* first_pyint = numpy::PyNumberToPyInt(first); + if (!first_pyint) { + PyErr_SetString( + PyExc_TypeError, + "First pair item cannot be converted to int"); + Py_DECREF(o); + SWIG_fail; + } + PyObject* second = PyTuple_GetItem(o, 1); + if (!second) { + Py_DECREF(o); + Py_DECREF(first_pyint); + SWIG_fail; + } + PyObject* second_pyint = numpy::PyNumberToPyInt(second); + if (!second_pyint) { + PyErr_SetString( + PyExc_TypeError, + "Second pair item cannot be converted to int"); + Py_DECREF(o); + Py_DECREF(first_pyint); + SWIG_fail; + } + const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); + if (first_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + SWIG_fail; + } + const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); + if (second_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + SWIG_fail; + } + temps.push_back(std::make_pair(first_value, second_value)); + Py_DECREF(o); + } + $1 = temps; +} + +// DotDimensionNumbers + +%typemap(in) const DotDimensionNumbers& + (DotDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "lhs_contracting_dimensions", + dimension_numbers.mutable_lhs_contracting_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "rhs_contracting_dimensions", + dimension_numbers.mutable_rhs_contracting_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "lhs_batch_dimensions", + dimension_numbers.mutable_lhs_batch_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "rhs_batch_dimensions", + dimension_numbers.mutable_rhs_batch_dimensions())) { + SWIG_fail; + } + + $1 = &dimension_numbers; +} + +// PaddingConfig + +%typemap(in) const PaddingConfig& + (PaddingConfig padding_config) { + PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); + if (!dimensions) { + SWIG_fail; + } + + int length = PySequence_Size(dimensions); + if (length == -1) { + Py_DECREF(dimensions); + SWIG_fail; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(dimensions, i); + if (!item) { + Py_DECREF(dimensions); + SWIG_fail; + } + int64 edge_padding_low, edge_padding_high, interior_padding; + if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) + || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) + || !GetIntAttr(item, "interior_padding", &interior_padding)) { + Py_DECREF(item); + Py_DECREF(dimensions); + SWIG_fail; + } + Py_DECREF(item); + + PaddingConfig::PaddingConfigDimension* dimension = + padding_config.add_dimensions(); + dimension->set_edge_padding_low(edge_padding_low); + dimension->set_edge_padding_high(edge_padding_high); + dimension->set_interior_padding(interior_padding); + } + Py_DECREF(dimensions); + + $1 = &padding_config; +} + +// ConvolutionDimensionNumbers + +%typemap(in) const ConvolutionDimensionNumbers& + (ConvolutionDimensionNumbers dimension_numbers) { + int64 value; + + if (!GetIntAttr($input, "input_batch_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_input_batch_dimension(value); + + if (!GetIntAttr($input, "input_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_input_feature_dimension(value); + + if (!GetIntAttr($input, "output_batch_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_output_batch_dimension(value); + + if (!GetIntAttr($input, "output_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_kernel_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_kernel_input_feature_dimension(value); + + if (!HandleRepeatedInt64Attribute( + $input, "input_spatial_dimensions", + dimension_numbers.mutable_input_spatial_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "kernel_spatial_dimensions", + dimension_numbers.mutable_kernel_spatial_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "output_spatial_dimensions", + dimension_numbers.mutable_output_spatial_dimensions())) { + SWIG_fail; + } + + $1 = &dimension_numbers; +} + +// GatherDimensionNumbers + +%typemap(in) const GatherDimensionNumbers& + (GatherDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "offset_dims", + dimension_numbers.mutable_offset_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "collapsed_slice_dims", + dimension_numbers.mutable_collapsed_slice_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "start_index_map", + dimension_numbers.mutable_start_index_map())) { + SWIG_fail; + } + + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { + SWIG_fail; + } + dimension_numbers.set_index_vector_dim(value); + + $1 = &dimension_numbers; +} + +// ScatterDimensionNumbers + +%typemap(in) const ScatterDimensionNumbers& + (ScatterDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "update_window_dims", + dimension_numbers.mutable_update_window_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "inserted_window_dims", + dimension_numbers.mutable_inserted_window_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "scatter_dims_to_operand_dims", + dimension_numbers.mutable_scatter_dims_to_operand_dims())) { + SWIG_fail; + } + + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { + SWIG_fail; + } + dimension_numbers.set_index_vector_dim(value); + + $1 = &dimension_numbers; +} + +// Span + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + ReplicaGroup rgrp; + if (!HandleRepeatedInt64Attribute( + o, "replica_ids", + rgrp.mutable_replica_ids())) { + SWIG_fail; + } + temps.push_back(rgrp); + Py_DECREF(o); + } + $1 = temps; +} diff --git a/tensorflow/compiler/xla/python/xrt.cc b/tensorflow/compiler/xla/python/xrt.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c55abc17f87c369e3d5b2140a84014e07921a9a --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.cc @@ -0,0 +1,297 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/xrt.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" +#include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace swig { + +XrtAllocation::XrtAllocation(int64 handle, Shape shape, + const string& session_target) + : handle_(handle), shape_(shape), session_target_(session_target) {} + +XrtAllocation::~XrtAllocation() { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto allocation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto release = + tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle); + if (!root.status().ok()) { + LOG(ERROR) << root.status(); + return; + } + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({allocation_handle, handle()}); + std::vector outputs; + auto status = session.Run(inputs, {}, {release}, &outputs); + if (!status.ok()) { + LOG(ERROR) << status; + return; + } +} + +/* static */ +StatusOr XrtAllocation::FromLiteral( + const Literal& argument, const string& session_target) { + xrt::XLAAllocation alloc; + *alloc.mutable_value() = argument.ToProto(); + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto literal_string = + tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({literal_string, alloc.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs)); + + int64 handle = outputs[0].scalar()(); + return new XrtAllocation(handle, argument.shape(), session_target); +} + +const int64 XrtAllocation::handle() const { return handle_; } + +const Shape& XrtAllocation::shape() const { return shape_; } + +StatusOr XrtAllocation::ToLiteral() const { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto allocation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({allocation_handle, handle()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs)); + + xla::LiteralProto response; + TF_RET_CHECK(response.ParseFromString(outputs[0].scalar()())); + return Literal::CreateFromProto(response); +} + +XrtAllocationTuple::XrtAllocationTuple(std::vector elements) + : elements_(std::move(elements)) { + for (auto* element : elements_) { + CHECK(element != nullptr); + } +} + +XrtAllocationTuple::~XrtAllocationTuple() { + for (XrtAllocation* element : elements_) { + if (element != nullptr) { + delete element; + } + } +} + +StatusOr XrtAllocationTuple::Release(int i) { + XrtAllocation* element = elements_[i]; + if (element == nullptr) { + return InvalidArgument("Attempted to release already-released element %d.", + i); + } + elements_[i] = nullptr; + return element; +} + +int64 XrtAllocationTuple::size() const { return elements_.size(); } + +StatusOr XrtExecutable::CompileForXrt( + const string& hlo_module_proto, const std::vector& argument_shapes, + const Shape& result_shape, const string& session_target) { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto compile = tensorflow::ops::XRTCompile(root, program); + TF_RETURN_IF_ERROR(root.status()); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + ProgramShape program_shape; + for (auto& shape : argument_shapes) { + *program_shape.add_parameters() = shape; + } + *program_shape.mutable_result() = result_shape; + + LayoutUtil::SetToDefaultLayout(&program_shape); + *config->mutable_program_shape() = program_shape.ToProto(); + c.mutable_hlo_snapshot() + ->mutable_hlo() + ->mutable_hlo_module() + ->ParsePartialFromString(hlo_module_proto); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({program, c.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs)); + + int64 handle = outputs[0].scalar()(); + return new XrtExecutable(program_shape, handle, session_target); +} + +XrtExecutable::XrtExecutable(const ProgramShape& program_shape, int64 handle, + const string& session_target) + : program_shape_(program_shape), + handle_(handle), + session_target_(session_target) {} + +XrtExecutable::~XrtExecutable() { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto computation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto release = + tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle); + if (!root.status().ok()) { + LOG(ERROR) << root.status(); + return; + } + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({computation_handle, handle()}); + std::vector outputs; + auto status = session.Run(inputs, {}, {release}, &outputs); + if (!status.ok()) { + LOG(ERROR) << status; + return; + } +} + +StatusOr XrtExecutable::Execute( + absl::Span argument_handles) { + const int num_expected_arguments = program_shape().parameters().size(); + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + std::vector arguments; + arguments.reserve(num_expected_arguments); + for (int i = 0; i < num_expected_arguments; ++i) { + arguments.push_back( + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64)); + } + auto computation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto execution_config = + tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto execute = tensorflow::ops::XRTExecute(root, computation_handle, + execution_config, arguments); + TF_RETURN_IF_ERROR(root.status()); + + TF_RET_CHECK(argument_handles.size() == arguments.size()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(false); + e.set_release_compilation_handle(false); + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + for (int i = 0; i < arguments.size(); ++i) { + inputs.insert({arguments[i], argument_handles[i]->handle()}); + } + inputs.insert({computation_handle, handle()}); + inputs.insert({execution_config, e.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs)); + + int64 output = outputs[0].scalar()(); + return new XrtAllocation(output, program_shape().result(), session_target_); +} + +const ProgramShape& XrtExecutable::program_shape() const { + return program_shape_; +} + +int64 XrtExecutable::handle() const { return handle_; } + +void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } + +void DeleteXrtExecutable(XrtExecutable* computation) { delete computation; } + +StatusOr DestructureXrtAllocationTuple( + XrtAllocation* allocation, const string& session_target) { + const Shape& tuple_shape = allocation->shape(); + + if (!tuple_shape.IsTuple()) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString(tuple_shape)); + } + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32); + auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + std::vector results; + for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + inputs.clear(); + inputs.insert({base_handle, allocation->handle()}); + inputs.insert({shape_index, {i}}); + std::vector outputs; + auto status = session.Run(inputs, {subtuple}, &outputs); + if (!status.ok()) { + // Clean up before returning non-ok status. + for (int j = 0; j < results.size(); ++j) { + delete results[j]; + } + return status; + } + const int64 subtuple_handle = outputs[0].scalar()(); + const Shape& subtuple_shape = + ShapeUtil::GetTupleElementShape(tuple_shape, i); + results.push_back( + new XrtAllocation(subtuple_handle, subtuple_shape, session_target)); + } + return new XrtAllocationTuple(std::move(results)); +} + +} // namespace swig +} // namespace xla diff --git a/tensorflow/compiler/xla/python/xrt.h b/tensorflow/compiler/xla/python/xrt.h new file mode 100644 index 0000000000000000000000000000000000000000..dd5bba6d5c9641dadc323f70745e870c14543321 --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.h @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" + +namespace xla { +namespace swig { + +// Represents a reference to literals that live in a device-allocated buffer via +// XRT. Specifically, wraps an int64 handle produced by running the allocation +// graph, and an XLA shape to track the referent's shape. +class XrtAllocation { + public: + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which allocation and deallocation + // graphs are run. + static StatusOr FromLiteral(const Literal& argument, + const string& session_target); + + XrtAllocation(int64 handle, Shape shape, const string& session_target); + ~XrtAllocation(); + StatusOr ToLiteral() const; + const Shape& shape() const; + const int64 handle() const; + + private: + const int64 handle_; + const Shape shape_; + const string session_target_; +}; + +// Result of a tuple destructuring operation on an XrtAllocation. +class XrtAllocationTuple { + public: + // Note: any XrtAllocation elements that are not Release()'d will be + // deallocated in the destructor. + explicit XrtAllocationTuple(std::vector elements); + + ~XrtAllocationTuple(); + + // Releases the ith element to the caller. Further attempts to release the ith + // element will return an invalid argument error. + StatusOr Release(int i); + + // Returns the number of elements in the destructured tuple. + int64 size() const; + + private: + std::vector elements_; +}; + +// Destructures a tuple-valued XrtAllocation into its constitutent elements +// in XrtAllocationTuple form. +// +// Accepts a `session_target` argument, used in constructing the +// `tensorflow::ClientSession` instance in which the sub-tupling graph is run, +// and passed along in constructing each constituent XrtAllocation. +StatusOr DestructureXrtAllocationTuple( + XrtAllocation* allocation, const string& session_target); + +// Represents a compiled computation that can be executed given handles to +// device-allocated literals. Specifically, wraps an XRT computation handle. +class XrtExecutable { + public: + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the compilation graph is run. + static StatusOr CompileForXrt( + const string& hlo_module_proto, const std::vector& argument_shapes, + const Shape& result_shape, const string& session_target); + + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the execution graph is run. + XrtExecutable(const ProgramShape& program_shape, int64 handle, + const string& session_target); + ~XrtExecutable(); + + std::vector DeviceOrdinals() const { return {0}; } + + StatusOr Execute( + absl::Span argument_handles); + + const ProgramShape& program_shape() const; + int64 handle() const; + + private: + const ProgramShape program_shape_; + const int64 handle_; + const string session_target_; +}; + +// Functions for freeing resources from the Python side. +void DeleteXrtAllocation(XrtAllocation* allocation); +void DeleteXrtExecutable(XrtExecutable* computation); + +} // namespace swig +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ diff --git a/tensorflow/compiler/xla/python/xrt.i b/tensorflow/compiler/xla/python/xrt.i new file mode 100644 index 0000000000000000000000000000000000000000..456dd7be86e479b46815fc16b51a10431fe2060d --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.i @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Wrappers for XRT ops. + +%module(threads="1") xrt + +// Keep the GIL except where explicitly specified. +%nothread; + +%include "tensorflow/python/platform/base.i" +%include "tensorflow/compiler/xla/python/xla_data.i" + +%{ +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/python/xrt.h" + +using namespace xla; +using namespace xla::swig; + +%} + +// Computation and buffer/allocation types + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtExecutable*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtAllocation*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtAllocationTuple*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + XrtAllocation* xrta; + if ((SWIG_ConvertPtr(o, (void**) &xrta, $descriptor(xla::swig::XrtAllocation*), + SWIG_POINTER_EXCEPTION)) == -1) { + SWIG_fail; + } + temps.push_back(xrta); + Py_DECREF(o); + } + $1 = temps; +} + + +%ignoreall +%unignore xla; +%unignore xla::swig; +%unignore xla::swig::XrtAllocation; +%unignore xla::swig::XrtAllocation::FromLiteral; +%unignore xla::swig::XrtAllocation::ToLiteral; +%unignore xla::swig::XrtAllocation::shape; +%unignore xla::swig::XrtAllocationTuple; +%unignore xla::swig::XrtAllocationTuple::Release; +%unignore xla::swig::XrtAllocationTuple::size; +%unignore xla::swig::XrtExecutable; +%unignore xla::swig::XrtExecutable::CompileForXrt; +%unignore xla::swig::XrtExecutable::DeviceOrdinals; +%unignore xla::swig::XrtExecutable::Execute; +%unignore xla::swig::DestructureXrtAllocationTuple; +%unignore xla::swig::DeleteXrtAllocation; +%unignore xla::swig::DeleteXrtExecutable; + +%thread; +%include "tensorflow/compiler/xla/python/xrt.h" +%nothread; + +%unignoreall diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index a5eae6d3962255d25a72362df45f0d8af52b1011..8d8394cb43ee013b9396a54e3a4d037445fcc0e1 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -680,7 +680,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", @@ -2204,6 +2203,8 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_map", @@ -3232,33 +3233,6 @@ tf_cc_test( ], ) -cc_library( - name = "hlo_tfgraph_builder", - srcs = ["hlo_tfgraph_builder.cc"], - hdrs = ["hlo_tfgraph_builder.h"], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "hlo_tfgraph_builder_test", - srcs = ["hlo_tfgraph_builder_test.cc"], - deps = [ - ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - ], -) - cc_library( name = "hlo_graph_dumper", srcs = [ @@ -3270,7 +3244,6 @@ cc_library( ":hlo", ":hlo_casting_utils", ":hlo_execution_profile", - ":hlo_tfgraph_builder", ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -3529,6 +3502,37 @@ tf_cc_test( ], ) +cc_library( + name = "stable_sort_expander", + srcs = ["stable_sort_expander.cc"], + hdrs = ["stable_sort_expander.h"], + deps = [ + ":hlo", + ":hlo_casting_utils", + ":hlo_pass", + ":op_expander_pass", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "stable_sort_expander_test", + srcs = ["stable_sort_expander_test.cc"], + deps = [ + ":algebraic_simplifier", + ":hlo_matchers", + ":hlo_parser", + ":pattern_matcher", + ":pattern_matcher_gmock", + ":stable_sort_expander", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + cc_library( name = "tuple_util", srcs = ["tuple_util.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index c5deb74e96ad5d9ff62d1407f84064dad63e61fb..bd17e96106abd9de0dd3bbf418439b0fb3edb746 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -280,15 +280,51 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { hlo)); } - // Helper method to perform and add reduction in a single dimension. - HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + // Converts to primitive type if the input hlo is not that type, otherwise + // returns the original hlo. + HloInstruction* AsType(HloInstruction* hlo, + const PrimitiveType element_type) { + if (hlo->shape().element_type() == element_type) { + return hlo; + } + return computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); + } + + // Transposes a dot operand such that the batch dimensions are the msot major, + // and the contracting dimensions are most minor. + StatusOr NormalizeDotOperandToBatchMajorAndContractingMinor( + HloInstruction* dot_operand, absl::Span batch_dimensions, + absl::Span contracting_dimensions) { + std::vector transpose_dimensions(batch_dimensions.begin(), + batch_dimensions.end()); + for (int64 i = 0; i < dot_operand->shape().rank(); ++i) { + if (!(absl::c_linear_search(batch_dimensions, i) || + absl::c_linear_search(contracting_dimensions, i))) { + transpose_dimensions.push_back(i); + } + } + transpose_dimensions.insert(transpose_dimensions.end(), + contracting_dimensions.begin(), + contracting_dimensions.end()); + return MakeTransposeHlo(dot_operand, transpose_dimensions); + } + + // Helper method to perform and add reduction on a list of dimensions. + HloInstruction* AddReduce(HloInstruction* hlo, absl::Span dims) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); - Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); + Shape shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { return !absl::c_linear_search(dims, dim); }, + hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( - shape, hlo, zero, {dim}, AddReduce_computation)); + shape, hlo, zero, dims, AddReduce_computation)); + } + + HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + return AddReduce(hlo, std::vector{dim}); } // Convenience method for replacing an instruction with a bitcast. If operand @@ -1120,16 +1156,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( std::swap(rhs_collapsing_dim, rhs_kept_dim); } - auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { - if (hlo->shape().element_type() == element_type) { - return hlo; - } - return computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); - }; - auto reshape_if_necessary = [&](HloInstruction* hlo) { - hlo = as_type(hlo, dot->shape().element_type()); + hlo = AsType(hlo, dot->shape().element_type()); if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { hlo = computation_->AddInstruction( HloInstruction::CreateReshape(dot->shape(), hlo)); @@ -1138,7 +1166,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( }; auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) { - return AddReduce(as_type(hlo, F32), dim); + return AddReduce(AsType(hlo, F32), dim); }; auto broadcast = [&](HloInstruction* hlo, const Shape& shape, @@ -1247,8 +1275,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( return dims; }; - // If the contracting dimension is 1, remove the degnerate dimnesions from the - // lhs and rhs, broadcast each to the result shape and multiply. + // If the contracting dimension is 1, remove the degnerate dimnensions from + // the lhs and rhs, broadcast each to the result shape and multiply. if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 && (rhs_kept_dim == rhs_rank - 1 || (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) { @@ -1608,34 +1636,26 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // If there are no contracting dimensions, a dot can be rewritten as // mul(broadcast(transpose(x)),broadcast(transpose(y))) if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) { - std::vector lhs_transpose( - dot->dot_dimension_numbers().lhs_batch_dimensions().begin(), - dot->dot_dimension_numbers().lhs_batch_dimensions().end()); - for (int64 i = 0; i < lhs->shape().rank(); ++i) { - if (!absl::c_linear_search( - dot->dot_dimension_numbers().lhs_batch_dimensions(), i)) { - lhs_transpose.push_back(i); - } - } - TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs, - MakeTransposeHlo(lhs, lhs_transpose)); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, + AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().lhs_contracting_dimensions()))); if (dot->shape().rank() != lhs->shape().rank()) { std::vector lhs_broadcast_dims(lhs->shape().rank()); absl::c_iota(lhs_broadcast_dims, 0); new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( dot->shape(), new_lhs, lhs_broadcast_dims)); } - std::vector rhs_transpose( - dot->dot_dimension_numbers().rhs_batch_dimensions().begin(), - dot->dot_dimension_numbers().rhs_batch_dimensions().end()); - for (int64 i = 0; i < rhs->shape().rank(); ++i) { - if (!absl::c_linear_search( - dot->dot_dimension_numbers().rhs_batch_dimensions(), i)) { - rhs_transpose.push_back(i); - } - } - TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs, - MakeTransposeHlo(rhs, rhs_transpose)); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, + AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().rhs_contracting_dimensions()))); if (dot->shape().rank() != rhs->shape().rank()) { std::vector rhs_broadcast_dims( dot->dot_dimension_numbers().lhs_batch_dimensions_size()); @@ -1651,6 +1671,78 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { new_lhs, new_rhs)); } + // If the lhs or rhs have only batch and contracting dimensions, a dot can be + // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) + if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == + lhs->shape().rank()) || + (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() + + dot->dot_dimension_numbers().rhs_batch_dimensions_size() == + rhs->shape().rank())) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, + AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().lhs_contracting_dimensions()))); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, + AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().rhs_contracting_dimensions()))); + + int64 lhs_outer_dims = + lhs->shape().rank() - + (dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); + int64 rhs_outer_dims = + rhs->shape().rank() - + (dot->dot_dimension_numbers().rhs_batch_dimensions_size() + + dot->dot_dimension_numbers().rhs_contracting_dimensions_size()); + CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0); + if (rhs_outer_dims > 0) { + std::vector lhs_broadcast_dims( + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + absl::c_iota(lhs_broadcast_dims, 0); + lhs_broadcast_dims.resize(lhs->shape().rank()); + std::iota(lhs_broadcast_dims.begin() + + dot->dot_dimension_numbers().lhs_batch_dimensions_size(), + lhs_broadcast_dims.end(), + dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + rhs_outer_dims); + new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + new_rhs->shape(), new_lhs, lhs_broadcast_dims)); + } else if (lhs_outer_dims > 0) { + std::vector rhs_broadcast_dims( + dot->dot_dimension_numbers().rhs_batch_dimensions_size()); + absl::c_iota(rhs_broadcast_dims, 0); + rhs_broadcast_dims.resize(rhs->shape().rank()); + std::iota(rhs_broadcast_dims.begin() + + dot->dot_dimension_numbers().rhs_batch_dimensions_size(), + rhs_broadcast_dims.end(), + dot->dot_dimension_numbers().rhs_batch_dimensions_size() + + lhs_outer_dims); + new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + new_lhs->shape(), new_rhs, rhs_broadcast_dims)); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs)); + std::vector reduce_dims( + dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); + new_dot = AsType(new_dot, F32); + const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims); + absl::c_iota( + reduce_dims, + outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + new_dot = AddReduce(new_dot, reduce_dims); + new_dot = AsType(new_dot, dot->shape().element_type()); + return ReplaceInstruction(dot, new_dot); + } + if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || dot->shape().rank() > 2) { if (options_.enable_dot_strength_reduction() && @@ -2583,11 +2675,11 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( int64 start = slice->slice_starts(i); int64 low = padding_config.dimensions(i).edge_padding_low(); int64 data = pad->operand(0)->shape().dimensions(i); - if (start >= low && start < low + data) { - return false; + if (start < low || start >= low + data) { + return true; } } - return true; + return false; }(); if (in_padding) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index feb6a0fb79538b38afa1110296b52061ec7f2259..af03fcb100813e8942efcaefc296b971c01a6aaa 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2753,8 +2753,9 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { Shape keys_shape = ShapeUtil::MakeShape(F32, {1}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - TF_ASSERT_OK( - MakeSortHlo(keys_shape, {keys}, 0, &builder, module.get()).status()); + TF_ASSERT_OK(MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, &builder, + module.get()) + .status()); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2775,7 +2776,8 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { HloInstruction::CreateParameter(2, values_shape, "values1")); TF_ASSERT_OK(MakeSortHlo(ShapeUtil::MakeTupleShape( {keys_shape, values_shape, values_shape}), - {keys, values0, values1}, 0, &builder, module.get()) + {keys, values0, values1}, 0, /*is_stable=*/false, + &builder, module.get()) .status()); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); @@ -3712,8 +3714,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(0); builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, DefaultPrecisionConfig(2))); std::unique_ptr dot_computation(builder.Build()); @@ -3958,7 +3960,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { param = f32[3,4] parameter(0) constant = f32[] constant(0.0) pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 - ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[4:5]} } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -3969,6 +3971,27 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } +TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); +} + TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { const char* hlo_string = R"( HloModule module @@ -3990,6 +4013,29 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { EXPECT_THAT(root, GmockMatch(m::Parameter())); } +TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) { + const char* hlo_string = R"( + HloModule module + + ENTRY entry () -> f32[1]{0} { + constant.val = f32[] constant(4) + constant.pad = f32[] constant(-7) + reshape.1 = f32[1,1,1]{2,1,0} reshape(f32[] constant.val) + pad = f32[3,3,3]{2,1,0} pad(f32[1,1,1]{2,1,0} reshape.1, f32[] constant.pad), padding=0_2x0_2x2_0 + slice = f32[1,1,1]{2,1,0} slice(f32[3,3,3]{2,1,0} pad), slice={[0:1], [0:1], [0:1]} + ROOT reshape.2 = f32[1]{0} reshape(f32[1,1,1]{2,1,0} slice) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::ConstantScalar(-7.0)))); +} + TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { const char* hlo_string = R"( HloModule module @@ -4220,12 +4266,24 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { int m, k, n; PrimitiveType element_type; std::tie(m, k, n, element_type) = GetParam(); - - Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n}); - Shape lhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k}) - : ShapeUtil::MakeShape(element_type, {1, 3, 5, m}); - Shape rhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n}) - : ShapeUtil::MakeShape(element_type, {1, 3, 5, n}); + std::vector lhs_dims = {1, 3, 5}; + std::vector rhs_dims = lhs_dims; + std::vector output_dims = lhs_dims; + if (m > 0) { + lhs_dims.push_back(m); + output_dims.push_back(m); + } + if (k > 0) { + lhs_dims.push_back(k); + rhs_dims.push_back(k); + } + if (n > 0) { + rhs_dims.push_back(n); + output_dims.push_back(n); + } + Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims); HloComputation::Builder builder(TestName()); auto lhs = builder.AddInstruction( @@ -4240,7 +4298,7 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { dot_dnums.add_rhs_batch_dimensions(1); dot_dnums.add_rhs_batch_dimensions(2); if (k > 0) { - dot_dnums.add_lhs_contracting_dimensions(4); + dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 4 : 3); dot_dnums.add_rhs_contracting_dimensions(3); } builder.AddInstruction(HloInstruction::CreateDot( @@ -4248,9 +4306,9 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); - const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1 || k == -1; - const bool computation_should_be_modified = dot_should_be_transformed; - EXPECT_EQ(changed, computation_should_be_modified); + const bool dot_should_be_transformed = + m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1; + EXPECT_EQ(changed, dot_should_be_transformed); bool has_no_dot = true; for (const auto& hlo : computation->instructions()) { if (hlo->opcode() == HloOpcode::kDot) { @@ -4261,10 +4319,12 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { EXPECT_EQ(has_no_dot, dot_should_be_transformed); } -INSTANTIATE_TEST_SUITE_P( - BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, - ::testing::Combine(::testing::Values(1, 2), ::testing::Values(-1, 1, 2), - ::testing::Values(1, 2), ::testing::Values(F32, BF16))); +INSTANTIATE_TEST_SUITE_P(BatchDotStrengthReductionTestInstantiation, + BatchDotStrengthReductionTest, + ::testing::Combine(::testing::Values(-1, 1, 2), + ::testing::Values(-1, 1, 2), + ::testing::Values(-1, 1, 2), + ::testing::Values(F32, BF16))); class DotStrengthReductionTest : public AlgebraicSimplifierTest, diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 215e8ced4bb3f98a26ac4eb9912a7fd4d917852f..d016d3e03d5e994841b81cda6214b6ff7cb550be 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/byte_order.h" @@ -67,18 +66,38 @@ const absl::optional>& BackendOptions::allowed_devices() const { return allowed_devices_; } +namespace { + +class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + explicit EigenThreadPoolWrapper(tensorflow::thread::ThreadPool* pool) + : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} + + void Schedule(std::function fn) override { + pool_->Schedule(std::move(fn)); + } + int NumThreads() const override { return pool_->NumThreads(); } + int CurrentThreadId() const override { return pool_->CurrentThreadId(); } + + private: + tensorflow::thread::ThreadPool* pool_ = nullptr; +}; + +} // namespace + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. -struct Backend::EigenThreadPoolWrapper { - explicit EigenThreadPoolWrapper(const int num_threads) +struct Backend::IntraOpThreadPool { + explicit IntraOpThreadPool(const int num_threads) : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(), "XLAEigen", num_threads)), - wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())), + wrapper(new EigenThreadPoolWrapper(pool.get())), device(new Eigen::ThreadPoolDevice(wrapper.get(), wrapper->NumThreads())) {} std::unique_ptr pool; - std::unique_ptr wrapper; + std::unique_ptr wrapper; std::unique_ptr device; }; @@ -146,8 +165,7 @@ Backend::Backend(se::Platform* platform, Compiler* compiler, const int num_threads = intra_op_parallelism_threads > 0 ? intra_op_parallelism_threads : tensorflow::port::NumSchedulableCPUs(); - intra_op_thread_pool_wrapper_.reset( - new EigenThreadPoolWrapper(num_threads)); + intra_op_thread_pool_.reset(new IntraOpThreadPool(num_threads)); } } @@ -159,17 +177,17 @@ int Backend::default_device_ordinal() const { const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device() const { - if (intra_op_thread_pool_wrapper_ == nullptr) { + if (intra_op_thread_pool_ == nullptr) { return nullptr; } - return intra_op_thread_pool_wrapper_->device.get(); + return intra_op_thread_pool_->device.get(); } tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const { - if (intra_op_thread_pool_wrapper_ == nullptr) { + if (intra_op_thread_pool_ == nullptr) { return nullptr; } - return intra_op_thread_pool_wrapper_->pool.get(); + return intra_op_thread_pool_->pool.get(); } StatusOr Backend::stream_executor( diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index c35f033dc0180409ae3888c2050021da83f5c72a..e7f29a044b95015aa7e547373c24971646833280 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -156,7 +156,6 @@ class Backend { Status ResetDevices(); private: - struct EigenThreadPoolWrapper; Backend(se::Platform* platform, Compiler* compiler, absl::Span stream_executors, TransferManager* transfer_manager, @@ -183,7 +182,8 @@ class Backend { std::unique_ptr memory_allocator_; // For the CPU backend, an Eigen threadpool device for use by Eigen code. - std::unique_ptr intra_op_thread_pool_wrapper_; + struct IntraOpThreadPool; + std::unique_ptr intra_op_thread_pool_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index eda026ac5685dc469a6230094eb28b3618e36400..dbabd82dd55465dd4c85a56aea849a3e3702d6bf 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -28,6 +28,13 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( *rhs = batch_dot->mutable_operand(1); const Shape& lhs_shape = lhs->shape(); + // A dot with no contracting dims will be rewritten into a multiply by + // AlgebraicSimplifier. Dots with multiple contracting dims are currently + // unsupported. + if (dim_numbers.lhs_contracting_dimensions_size() != 1) { + return false; + } + std::vector degenerate_dims; for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) { if (lhs_shape.dimensions(batch_dim) == 1) { diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 52ec1a794c5e9f4452a4bf2b648f453d8acfe976..a81f394a38f091b89b7f1e4d26653ff549f35b75 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -169,5 +169,47 @@ main { /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2))); } +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDimsNonContracting) { + const char* hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,101] parameter(0) + b = f32[1,101] parameter(1) + ROOT dot = f32[1,101,101] dot(a,b), lhs_batch_dims={0}, + lhs_contracting_dims={}, + rhs_batch_dims={0}, + rhs_contracting_dims={} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + BatchDotSimplification pass; + ASSERT_FALSE(pass.Run(m.get()).ValueOrDie()); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDimsMultipleContracting) { + const char* hlo_text = R"( +HloModule BatchDot + +main { + lhs = f32[1,5,17,10,13] parameter(0) + rhs = f32[1,9,10,13,6,5] parameter(1) + ROOT dot = f32[10,1,17,9,6] dot(lhs,rhs), lhs_batch_dims={3,0}, + rhs_batch_dims={2,0}, + lhs_contracting_dims={1,4}, + rhs_contracting_dims={5,3} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + BatchDotSimplification pass; + ASSERT_FALSE(pass.Run(m.get()).ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 2591ff602c8467afef9a1cbacd9aff2e63a8457e..2caa979745b3b40817acb1b6951e1de5ffa294a4 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -286,7 +286,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { TF_ASSERT_OK_AND_ASSIGN( auto* sort, MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), - {key, value}, 0, &builder, module.get())); + {key, value}, 0, /*is_stable=*/false, &builder, + module.get())); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); @@ -314,7 +315,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { TF_ASSERT_OK_AND_ASSIGN( auto* sort, MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), - {key, value}, 0, &builder, module.get())); + {key, value}, 0, /*is_stable=*/false, &builder, + module.get())); auto computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index c02ffda575278905f6549b362e5e7d94f5713b36..57a636fd740995d6cce933fe19d5592a64bde5cf 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -30,7 +30,7 @@ namespace xla { // The context in which a computation is called by another computation. enum class CallContext { - // In a parallel contex the computation is applied to each element of the + // In a parallel context the computation is applied to each element of the // array argument(s). kMap and kReduce instructions call computations in // parallel context. kParallel, diff --git a/tensorflow/compiler/xla/service/cpu/build_defs.bzl b/tensorflow/compiler/xla/service/cpu/build_defs.bzl index e78330b21689fdd818cd97128bbcaaa9e0118602..ffa1cd4ec8e26e7dbe92e7b99cf65e99db5400b9 100644 --- a/tensorflow/compiler/xla/service/cpu/build_defs.bzl +++ b/tensorflow/compiler/xla/service/cpu/build_defs.bzl @@ -1,12 +1,11 @@ """build_defs for service/cpu.""" - def runtime_copts(): - """Returns copts used for CPU runtime libraries.""" - return (["-DEIGEN_AVOID_STL_ARRAY"] + select({ - "//tensorflow:android_arm": ["-mfpu=neon"], - "//conditions:default": [] - }) + select({ - "//tensorflow:android": ["-O2"], - "//conditions:default": [] - })) + """Returns copts used for CPU runtime libraries.""" + return (["-DEIGEN_AVOID_STL_ARRAY"] + select({ + "//tensorflow:android_arm": ["-mfpu=neon"], + "//conditions:default": [], + }) + select({ + "//tensorflow:android": ["-O2"], + "//conditions:default": [], + })) diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 5abb3eb38725a3d8e2b761abff1b66f35e92c130..9967cf28ee2389f9bef9780d2c986140f9bf2682 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -583,7 +583,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { b_.getVoidTy(), {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(), - b_.getInt32Ty()->getPointerTo(), b_.getInt8PtrTy(), + b_.getInt32Ty()->getPointerTo(), b_.getInt1Ty(), b_.getInt8PtrTy(), b_.getInt64Ty()->getPointerTo(), less_than_function->getType()}, /*isVarArg=*/false); auto* key_value_sort_func = llvm::dyn_cast( @@ -616,8 +616,8 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), b_.getInt64(lower_dimensions), values, b_.getInt32(sort->operand_count()), sizes, - GetExecutableRunOptionsArgument(), GetProfileCountersArgument(), - less_than_function}); + b_.getInt1(sort->is_stable()), GetExecutableRunOptionsArgument(), + GetProfileCountersArgument(), less_than_function}); if (sort->values_count() > 0) { llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index cb46674138acf8e1daea24102f988a9a355ec5c8..70a6d0af02c0c2db7208db561cf29e35a74707b2 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -32,8 +32,8 @@ using tensorflow::int64; TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes, char* run_options, - int64* prof_counters, + int32* values_primitive_type_size_in_bytes, bool is_stable, + char* run_options, int64* prof_counters, void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)) { // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT // code, so msan can't tell they are initialized. @@ -69,22 +69,27 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( int64 base_offset = index % sort_dimension_offset + (index - index % sort_dimension_offset) * sort_dimension_elements; - std::stable_sort( - indices.get(), indices.get() + sort_dimension_elements, - [&](int64 a, int64 b) -> bool { - int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * - values_primitive_type_size_in_bytes[0]; - int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * - values_primitive_type_size_in_bytes[0]; - for (int32 i = 0; i < values_count; ++i) { - comparison_values[i * 2] = values[i] + memory_index_lhs; - comparison_values[i * 2 + 1] = values[i] + memory_index_rhs; - } - char result = 0; // Overwritten by less_than. - less_than(&result, run_options, comparison_values.get(), nullptr, - prof_counters); - return result != 0u; - }); + auto compare_function = [&](int64 a, int64 b) -> bool { + int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + for (int32 i = 0; i < values_count; ++i) { + comparison_values[i * 2] = values[i] + memory_index_lhs; + comparison_values[i * 2 + 1] = values[i] + memory_index_rhs; + } + char result = 0; // Overwritten by less_than. + less_than(&result, run_options, comparison_values.get(), nullptr, + prof_counters); + return result != 0u; + }; + if (is_stable) { + std::stable_sort(indices.get(), indices.get() + sort_dimension_elements, + compare_function); + } else { + std::sort(indices.get(), indices.get() + sort_dimension_elements, + compare_function); + } // Reorder the values according to the order defined by 'indices'. for (int32 idx = 0; idx < values_count; ++idx) { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h index 4813de9ee67282110959f943cd70fcfe2ca94d9d..50c2911c3bd392b6df12717c34d250ce86ad26e0 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -22,15 +22,14 @@ limitations under the License. extern "C" { // Each entry in 'values' represents a 3-dimensional shape with dimensions -// [a, b, c]. The 'b' dimension of the first shape is sorted into ascending -// order according to the results of comparisons using the provided 'less_than' +// [a, b, c]. The 'b' dimension of each shape is sorted into ascending order +// according to the results of comparisons using the provided 'less_than' // function. 'values_count' must be > 0 and specifies the number of entries in // 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive // type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]' -// bytes. The elements in each 'values' shape are reordered in the same way -// according to the comparisons using the first shape. 'run_options' and -// 'prof_counters' are passed through to the less-than function, which expects -// the following arguments: +// bytes. 'is_stable' specifies whether the sorting should be stable. +// 'run_options' and 'prof_counters' are passed through to the less-than +// function, which expects the following arguments: // - pointer to the return value buffer (char*) // - xla::ExecutableRunOptions = 'run_options' (char*) // - pointers to the parameter buffers (char**) @@ -39,8 +38,8 @@ extern "C" { extern void __xla_cpu_runtime_KeyValueSort( tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes, char* run_options, - tensorflow::int64* prof_counters, + tensorflow::int32* values_primitive_type_size_in_bytes, bool is_stable, + char* run_options, tensorflow::int64* prof_counters, void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)); } diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc index e54f205465e0f453766c61f10102e104ee2cf5a6..9fc472ff767441e60cf618ac9022e5c50ea20023 100644 --- a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc @@ -948,15 +948,16 @@ llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) { return type->getPointerTo(); } -struct GemvInputsWithCanonicalType { +struct GemvBuffersWithCanonicalType { llvm::Value* lhs_canonicalized; llvm::Value* rhs_canonicalized; llvm::Value* addend_canonicalized; + llvm::Value* result_canonicalized; }; -GemvInputsWithCanonicalType GetGemvInputsWithCanonicalType( +GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType( llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, - llvm::IRBuilder<>* b) { + llvm::Value* result, llvm::IRBuilder<>* b) { // We characterize a GEMV operation via M and K, since N is implicitly 1. // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented // by the same GEMV that multiplies [5,6] with [1,6]. However, the @@ -965,20 +966,23 @@ GemvInputsWithCanonicalType GetGemvInputsWithCanonicalType( // from the `xla::Shape`s. Since we want to be able to call the same // `llvm::Function` for the two GEMVs we canonicalize the types of the GEMV // inputs here into the same type. - GemvInputsWithCanonicalType result; + GemvBuffersWithCanonicalType buffers_with_canonical_type; llvm::Type* lhs_type = lhs->getType(); llvm::Type* rhs_type = rhs->getType(); llvm::Type* addend_type = addend ? addend->getType() : nullptr; + llvm::Type* result_type = result->getType(); - result.lhs_canonicalized = + buffers_with_canonical_type.lhs_canonicalized = b->CreateBitCast(lhs, GetPointerToElementType(lhs_type)); - result.rhs_canonicalized = + buffers_with_canonical_type.rhs_canonicalized = b->CreateBitCast(rhs, GetPointerToElementType(rhs_type)); - result.addend_canonicalized = + buffers_with_canonical_type.addend_canonicalized = addend ? b->CreateBitCast(addend, GetPointerToElementType(addend_type)) : nullptr; + buffers_with_canonical_type.result_canonicalized = + b->CreateBitCast(result, GetPointerToElementType(result_type)); - return result; + return buffers_with_canonical_type; } } // namespace @@ -993,14 +997,15 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); - GemvInputsWithCanonicalType canonical_inputs = - GetGemvInputsWithCanonicalType(lhs, rhs, addend, b); + GemvBuffersWithCanonicalType canonical_inputs = + GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, - canonical_inputs.addend_canonicalized, result, + canonical_inputs.addend_canonicalized, + canonical_inputs.result_canonicalized, [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result) { @@ -1020,14 +1025,15 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); - GemvInputsWithCanonicalType canonical_inputs = - GetGemvInputsWithCanonicalType(lhs, rhs, addend, b); + GemvBuffersWithCanonicalType canonical_inputs = + GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, - canonical_inputs.addend_canonicalized, result, + canonical_inputs.addend_canonicalized, + canonical_inputs.result_canonicalized, [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result) { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index e868dc6d889c867001bf2a145bb9277c56950401..808929be75ec6fd0cfb15418a231431b8d51e089 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1367,26 +1367,69 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_); llvm::Type* raw_value_ty = raw_value->getType(); - // Convert raw integer to float in range [0, 1) if the element is a float. + // If we're generating a floating-point value, convert the raw integer R (i.e. + // `raw_value`) to a float in the range [0, 1). + // + // The basic approach is to choose a significand and exponent such that the + // significand is uniformly distributed and the exponent is distributed, well, + // exponentially (it's more likely to be close to 0 than far from 0). + // + // An easy way to do this is to say that the significand is the first S bits + // of R, and the exponent is determined by the number of trailing zeroes in R, + // exp = 2^-(cttz(R) + 1). (+1 because the largest exponent should be -1; + // this way the largest value we can return is 1.999... * 2^-1 = 1-ε.) + // + // This results in a small bias. Namely, if R has enough trailing zeroes, the + // significand and exponent will "overlap". As a concrete example, consider + // + // 20 X's 12 zeroes + // R = 0bXXXXXXXXXXXXXXXXXXXX000000000000 + // + // Here the exponent is 2^-13 because R has 12 trailing zeroes. The + // significand is made up of the first 23 most-significant bits of R, which we + // observe contain 3 zeroes. This is biased because any random value with + // exponent 2^-12 will have a significand which ends in `000`. + // + // For f32s, this problem occurs only when there are more than 32-23 = 9 + // trailing zeros, which happens with probability 0.5^10 = ~0.1%. Moreover the + // probability of a large bias (i.e. many trailing 0s in the significand) is + // exponentially low. So we deem this acceptable. llvm::Value* elem_value = raw_value; if (elem_ir_ty->isFloatingPointTy()) { - unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits(); - CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64); - // Perform the division using the float type with the same number of bits - // as the raw value to avoid overflow. - if (raw_value_size_in_bits == 32) { - elem_value = UIToFP(elem_value, b_->getFloatTy()); - elem_value = FDiv(elem_value, - llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); - } else { - elem_value = UIToFP(elem_value, b_->getDoubleTy()); - elem_value = FDiv( - elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); - } - - if (elem_ir_ty != elem_value->getType()) { - elem_value = FPTrunc(elem_value, elem_ir_ty); - } + const auto& dest_flt_semantics = elem_ir_ty->getFltSemantics(); + const int bits = raw_value_ty->getPrimitiveSizeInBits(); + CHECK_GE(bits, llvm::APFloat::semanticsSizeInBits(dest_flt_semantics)); + + // Subtract 1 because semanticsPrecision includes the "hidden bit", i.e. the + // implicit "1." at the beginning of the significand. + const int significand_bits = + llvm::APFloat::semanticsPrecision(dest_flt_semantics) - 1; + + llvm::Value* cttz = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::cttz, {raw_value, /*is_zero_undef=*/b_->getFalse()}, + {raw_value->getType()}, b_); + llvm::Value* significand = LShr(raw_value, bits - significand_bits); + + // Exponent bias is -127 for f32, meaning that if the exponent is E and the + // significand is S, then the value of the number is 2^(E - 127) * (1.S). + // + // We want cttz == 0 to correspond to 2^-1, so our exponent is computed as + // E = 126 - cttz. + // + // For f64, this is all the same, except the bias is -1023. + // + // In IEEE floating point, the absolute value of the exponent bias equals + // the value of the largest possible exponent. + const int bias = -llvm::APFloat::semanticsMaxExponent(dest_flt_semantics); + llvm::Value* exponent = + Sub(llvm::ConstantInt::get(cttz->getType(), -bias - 1), cttz); + + // Now just slot everything into place! The `Trunc` is here because + // raw_value may be larger than our float destination. + elem_value = + BitCast(Trunc(Or(Shl(exponent, significand_bits), significand), + b_->getIntNTy(elem_ir_ty->getPrimitiveSizeInBits())), + elem_ir_ty); } // Convert the value for the requested distribution. diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index d3e2acaabd4f602171def70ccd3d4fd5adce0d0d..7d360fe38cfeda17878c363253c41883ec9fd64f 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -216,8 +216,11 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator); + // Converts the raw value generated by a random number generation algorithm // to the distribution requested by the RNG HloInstruction. + // + // Precondition: raw_value has at least as many bits as hlo's element type. StatusOr ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 05980fe549c4d9235ff80916cfad77ab60e1c447..25c4f70d89b4ebc483a61f1e28c7a55eb31f4bdf 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -765,6 +765,7 @@ cc_library( "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:sort_simplifier", + "//tensorflow/compiler/xla/service:stable_sort_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index e9d7ba1c4cfa865532a0d06c2ed883a2fea4e2cd..9f0de3f794decb7b878b67c96030f8e11b0555fe 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -48,7 +48,7 @@ bool IsInputFusibleReduction(const HloInstruction& instr); // Whether instruction shapes are compatible for multi-output fusion, i.e. // whether the emitters support lowering the resulting fusion. -// This function works for both, sibling and producer-conumser multi-output +// This function works for both, sibling and producer-consumer multi-output // fusion. // So far, multi-output fusion is supported for loop fusions and reduce // input fusions only. It is up to the caller to ensure the instructions diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index cb13e727a44166ec564b58106c18b5c7f28a4af2..8f010ab27a6c99b97e7808218de908ce558b0fe7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -492,8 +492,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); result = InsertValue(result, value.first, {0}); result = InsertValue(result, value.second, {1}); - } else { + } else if (ShapeUtil::ElementIsFloating(lhs_shape)) { result = FMul(lhs_value, rhs_value); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape)); + result = Mul(lhs_value, rhs_value); } target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); return Status::OK(); @@ -583,9 +586,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { llvm::Value* accum_imag = Imag(accum, &b_); llvm::Value* imag_sum = FAdd(accum_imag, value.second); updated_accum = InsertValue(updated_accum, imag_sum, {1}); - } else { + } else if (ShapeUtil::ElementIsFloating(lhs_shape)) { llvm::Value* product = FMul(lhs_element, rhs_element); updated_accum = FAdd(accum, product); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape)); + llvm::Value* product = Mul(lhs_element, rhs_element); + updated_accum = Add(accum, product); } Store(updated_accum, accum_address); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 9c8a1816040d99bd20af111e8e930149287ed146..6e00e4b4ff8c493f00fae3355215fb13fb5f4f10 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h" +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -195,6 +196,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/true); + // Expand the sort op to support stable sorting if required. + pipeline.AddPass(); // Convert BF16 operations to F32 operations so that the GPU backend can // support BF16 operations without directly implementing a BF16 lowering for // most ops. diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 6e64549e7e1cad80b740452816028c730053623f..ae9e3169fd9b7a4655ab91ffb1589b845402ba8d 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 60 +// Next ID: 62 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -175,6 +175,9 @@ message HloInstructionProto { // partners. bool is_host_transfer = 47; + // Whether this Sort instruction should be stable. + bool is_stable = 60; + xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; // Precision configuration for the instruction. Has backend-specific meaning. @@ -196,6 +199,9 @@ message HloInstructionProto { // Options for TriangularSolve xla.TriangularSolveOptions triangular_solve_options = 59; + + // Describes how parameters behave with regards to replicas. + xla.ParameterReplication parameter_replication = 61; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 40fe91398be33f5681e1389e1b6fadcbd87487bb..817e15f9ff10a9b7e1a502265c85f70fdd681dd9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -296,7 +296,7 @@ void ComputeComputationPostOrder(HloComputation* computation, } // namespace void HloComputation::ComputeInstructionPostOrder( - const HloComputation::ChannelDependencyMap& channel_dependency_map, + const HloComputation::ChannelDependencyGroup& channel_dependency_group, std::vector* post_order, HloInstruction* root, absl::flat_hash_map* visited) const { std::vector dfs_stack; @@ -320,66 +320,75 @@ void HloComputation::ComputeInstructionPostOrder( visited->insert({current, kVisiting}); - // Add the operands to the stack in reverse order so the first operand is - // processed first. This will produce a more natural ordering and a nicer - // result for things like HLO stringification. - const auto& operands = current->operands(); - for (int64 i = operands.size() - 1; i >= 0; --i) { - dfs_stack.emplace_back(operands[i]); - } - - for (HloInstruction* op : current->control_predecessors()) { - dfs_stack.emplace_back(op); - } - - // Add inputs for send->recv_done dependencies and all-reduce - // dependencies. - switch (current->opcode()) { - case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(current->channel_id()); - if (it != channel_dependency_map.end()) { - for (HloInstruction* op : it->second) { - dfs_stack.emplace_back(op); - } - } - break; + const auto get_channel_id = + [](HloInstruction* inst) -> absl::optional { + switch (inst->opcode()) { + case HloOpcode::kRecvDone: + return inst->channel_id(); + case HloOpcode::kAllReduce: + return inst->all_reduce_id(); + default: + return absl::nullopt; } - case HloOpcode::kAllReduce: { - auto all_reduce_id = current->all_reduce_id(); - if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - for (HloInstruction* op : it->second) { - dfs_stack.emplace_back(op); - } - } + }; + + // When adding a predecessor to the dfs_stack, we need to also add its + // associated channel dependencies. + const auto add_dfs_stack = [&](HloInstruction* inst) { + auto channel_id = get_channel_id(inst); + if (channel_id && channel_dependency_group.count(*channel_id)) { + auto it = channel_dependency_group.find(*channel_id); + for (HloInstruction* cinst : it->second) { + dfs_stack.emplace_back(cinst); } - break; + } else { + dfs_stack.emplace_back(inst); } - default: - break; + }; + + const auto add_predecessors = [&](HloInstruction* inst) { + // Add the operands to the stack in reverse order so the first operand is + // processed first. This will produce a more natural ordering and a nicer + // result for things like HLO stringification. + const auto& operands = inst->operands(); + for (int64 i = operands.size() - 1; i >= 0; --i) { + add_dfs_stack(operands[i]); + } + + for (HloInstruction* op : inst->control_predecessors()) { + add_dfs_stack(op); + } + }; + + // If the current instruction is a channel instruction, add the dependencies + // from all associated instructions of the channel. + auto channel_id = get_channel_id(current); + if (channel_id && channel_dependency_group.count(*channel_id)) { + auto it = channel_dependency_group.find(*channel_id); + for (HloInstruction* cinst : it->second) { + add_predecessors(cinst); + } + } else { + add_predecessors(current); } } } -HloComputation::ChannelDependencyMap +HloComputation::ChannelDependencyGroup HloComputation::ComputeChannelDependencies() const { - ChannelDependencyMap channel_dependency_map; + ChannelDependencyGroup channel_dependency_group; for (const auto& instruction : instructions_) { switch (instruction->opcode()) { - case HloOpcode::kSend: { - channel_dependency_map[instruction->channel_id()].push_back( + case HloOpcode::kSend: + case HloOpcode::kRecvDone: + channel_dependency_group[instruction->channel_id()].push_back( instruction.get()); break; - } case HloOpcode::kAllReduce: { auto all_reduce_id = instruction->all_reduce_id(); if (all_reduce_id) { - auto& dependencies = channel_dependency_map[all_reduce_id.value()]; - absl::c_copy(instruction->operands(), - std::back_inserter(dependencies)); - absl::c_copy(instruction->control_predecessors(), - std::back_inserter(dependencies)); + channel_dependency_group[all_reduce_id.value()].push_back( + instruction.get()); } break; } @@ -387,11 +396,11 @@ HloComputation::ComputeChannelDependencies() const { break; } } - return channel_dependency_map; + return channel_dependency_group; } std::vector HloComputation::MakeInstructionPostOrder() const { - auto channel_dependency_map = ComputeChannelDependencies(); + auto channel_dependency_group = ComputeChannelDependencies(); std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; @@ -404,7 +413,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - ComputeInstructionPostOrder(channel_dependency_map, &post_order, + ComputeInstructionPostOrder(channel_dependency_group, &post_order, instruction.get(), &visited); } } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index fd1f990431a87ef27d3d7b0ae56ba73c444bc1cc..212dfa15a13185f1050103739fad8b560270d401 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -369,13 +369,13 @@ class HloComputation { // channel complete). bool IsRemovable(const HloInstruction* instruction); - // Returns a map from channel-id to directed dependencies of the channel - // instructions. For send&recv pairs it means the send instruction and for - // all-reduce the union of the dependencies for all participating - // instructions. - using ChannelDependencyMap = + // Returns a map from channel-id to the group of instructions associated with + // the channel. These instructions will be considered as a single node for + // dependency purposes. Send and RecvDone are in the group, and AllReduces + // with the same channel id are in the group. + using ChannelDependencyGroup = absl::flat_hash_map>; - ChannelDependencyMap ComputeChannelDependencies() const; + ChannelDependencyGroup ComputeChannelDependencies() const; // Returns true if this computation has a side effect. A computation has a // side effect if it contains one or more instructions with a side effect. @@ -438,7 +438,7 @@ class HloComputation { enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( - const HloComputation::ChannelDependencyMap& channel_dependency_map, + const HloComputation::ChannelDependencyGroup& channel_dependency_map, std::vector* post_order, HloInstruction* root, absl::flat_hash_map* visited) const; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 3b88e9745c27d6e1f2a46e5c83ac2e8bd8d05150..fe37ca6b3963430c765f27aede4f506366fc5d97 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -24,7 +24,9 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -37,6 +39,7 @@ namespace xla { namespace { namespace m = match; +namespace op = xla::testing::opcode_matchers; using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; @@ -668,5 +671,34 @@ TEST_F(HloComputationTest, DeepEquality) { EXPECT_FALSE(*computation_c == *computation_b); } +// Tests that cross-module AllReduce instructions are ordered before all their +// predecessors and after all their successors. +TEST_F(HloComputationTest, InstructionPostOrderWithAllReduce) { + const char* const hlo_string = R"( +HloModule Module + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY entry { + param = f32[128] parameter(0), sharding={maximal device=0} + crs0 = f32[128] all-reduce(param), + replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add, + sharding={maximal device=0} + crs1 = f32[128] all-reduce(param), + replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add, + sharding={maximal device=1} + add = f32[128] add(crs0, crs0), sharding={maximal device=0} + ROOT t = (f32[128], f32[128]) tuple(add, crs1) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + EXPECT_THAT(module->entry_computation()->MakeInstructionPostOrder(), + ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(), + op::Add(), op::Tuple())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 070115604ba46dfe2de92b592a31e831ca2e1c87..b5d9e8e7f1a703d5d914a12d5226d53821071be6 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -275,7 +275,7 @@ StatusOr MakeSelectHlo(HloInstruction* pred, StatusOr MakeSortHlo( const Shape& sort_shape, absl::Span operands, - int64 dimension_to_sort, HloComputation::Builder* builder, + int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder, HloModule* module) { CHECK(!operands.empty()) << "Sort Hlo requires at least one operand."; HloComputation* compare_computation; @@ -293,7 +293,7 @@ StatusOr MakeSortHlo( compare_computation = module->DeepCloneComputation(new_module->entry_computation(), &context); return builder->AddInstruction(HloInstruction::CreateSort( - sort_shape, dimension_to_sort, operands, compare_computation)); + sort_shape, dimension_to_sort, operands, compare_computation, is_stable)); } StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 36b8cdc7feff9143a041ad0beb3a0dda91589618..17b7a2da6a9da994ea2d496b549eec79278b56b5 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -126,10 +126,10 @@ StatusOr MakeSelectHlo(HloInstruction* pred, // Creates a Sort HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. Also creates a // default compare sub-computation which sorts the first operand into ascending -// order. +// order. 'is_stable' specifies whether the sorting should be stable. StatusOr MakeSortHlo( const Shape& sort_shape, absl::Span operands, - int64 dimension_to_sort, HloComputation::Builder* builder, + int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder, HloModule* module); // Creates an R1 Constant HLO instruction of the given PrimitiveType with the diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index e3059e02cf0e527522811920a09154afd32976f5..768e3afb3b80698061b62c4aadef09c20e2f286c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2363,7 +2363,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); TF_ASSERT_OK_AND_ASSIGN( - auto* sort, MakeSortHlo(keys_shape, {keys}, -1, &builder, module_.get())); + auto* sort, MakeSortHlo(keys_shape, {keys}, -1, /*is_stable=*/false, + &builder, module_.get())); computation_ = module_->AddEntryComputation(builder.Build()); RunAnalysis(); @@ -2385,7 +2386,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { TF_ASSERT_OK_AND_ASSIGN( auto* sort, MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), - {keys, values}, 0, &builder, module_.get())); + {keys, values}, 0, /*is_stable=*/false, &builder, + module_.get())); computation_ = module_->AddEntryComputation(builder.Build()); RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 691d5c1bbc3edb0aa47acc52d5752020068c3515..4d6487700b24cfd3b89aece58e5ad6d7bb43a800 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -1493,44 +1494,47 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { std::vector indices_to_sort(sort_dim_elements); std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0); Status compare_status = Status::OK(); - std::stable_sort( - indices_to_sort.begin(), indices_to_sort.end(), - [sort, &compare_status, &embedded_evaluator, &literals_to_sort]( - int64 a, int64 b) { - std::vector literals; - literals.reserve(2 * sort->operand_count()); - for (int64 i = 0; i < sort->operand_count(); ++i) { - auto lhs = ExtractFromIndexPositions( - literals_to_sort[i], {a}, /*extract_as_scalar=*/true); - if (!lhs.ok()) { - compare_status = lhs.status(); - return false; - } - literals.push_back(std::move(lhs.ValueOrDie())); - auto rhs = ExtractFromIndexPositions( - literals_to_sort[i], {b}, /*extract_as_scalar=*/true); - if (!rhs.ok()) { - compare_status = rhs.status(); - return false; - } - literals.push_back(std::move(rhs.ValueOrDie())); - } - std::vector literal_ptrs; - absl::c_transform( - literals, std::back_inserter(literal_ptrs), - [](const Literal& literal) { return &literal; }); - - auto computed_result = - embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs); - // Clear visit states so that we can use the evaluator again - // on the same computation. - embedded_evaluator.ResetVisitStates(); - if (!computed_result.ok()) { - compare_status = computed_result.status(); - return false; - } - return computed_result.ValueOrDie().Get({}); - }); + auto comparator = [sort, &compare_status, &embedded_evaluator, + &literals_to_sort](int64 a, int64 b) { + std::vector literals; + literals.reserve(2 * sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a}, + /*extract_as_scalar=*/true); + if (!lhs.ok()) { + compare_status = lhs.status(); + return false; + } + literals.push_back(std::move(lhs.ValueOrDie())); + auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b}, + /*extract_as_scalar=*/true); + if (!rhs.ok()) { + compare_status = rhs.status(); + return false; + } + literals.push_back(std::move(rhs.ValueOrDie())); + } + std::vector literal_ptrs; + absl::c_transform(literals, std::back_inserter(literal_ptrs), + [](const Literal& literal) { return &literal; }); + + auto computed_result = + embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs); + // Clear visit states so that we can use the evaluator again + // on the same computation. + embedded_evaluator.ResetVisitStates(); + if (!computed_result.ok()) { + compare_status = computed_result.status(); + return false; + } + return computed_result.ValueOrDie().Get({}); + }; + if (Cast(sort)->is_stable()) { + std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(), + comparator); + } else { + std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator); + } if (!compare_status.ok()) { return compare_status; } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index fb8cd299cef06d549130cd56dd2c430c4c1a0387..383921fde22242b6ede95a6554f2348ab6fd4277 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -111,6 +111,24 @@ class HloEvaluatorTest : public HloTestBase { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } + void TestTernaryOp(HloOpcode opcode, Literal expected, Literal src0, + Literal src1, Literal src2) { + HloComputation::Builder b(TestName()); + auto operand0 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(src0))); + auto operand1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(src1))); + auto operand2 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(src2))); + b.AddInstruction(HloInstruction::CreateTernary( + expected.shape(), opcode, operand0, operand1, operand2)); + m_->AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); + } + protected: explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {} HloEvaluator evaluator_; @@ -152,6 +170,33 @@ TEST_P(HloEvaluatorBf16Test, DoesClamp) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +// Verifies that clamping of int64 does not cause loss of precision +TEST_P(HloEvaluatorBf16Test, DoesClampInt64) { + auto ones = [](int bits) { return (int64{1} << bits) - 1; }; + + auto low = + LiteralUtil::CreateR2({{0, ones(54)}, {ones(54), ones(58)}}); + auto value = LiteralUtil::CreateR2({{0, ones(56)}, {0, ones(58)}}); + auto high = LiteralUtil::CreateR2( + {{ones(54), ones(55)}, {ones(56), ones(58)}}); + + Shape shape = low.shape(); + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); + b.AddInstruction( + HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); + m_->AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + auto expected = + LiteralUtil::CreateR2({{0, ones(55)}, {ones(54), ones(58)}}); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) { auto low = LiteralUtil::CreateR0(0.f); auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); @@ -254,6 +299,20 @@ TEST_F(HloEvaluatorTest, DoesDivideInt64) { TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } + +TEST_F(HloEvaluatorTest, DoesClampS64) { + auto low = LiteralUtil::CreateR1( + {-8616761059752331528LL, 6780561065411491190LL, -8616761059752331528LL}); + auto value = LiteralUtil::CreateR1( + {-6780561065411491190LL, 6780561065411491180LL, 4241131823772864090LL}); + auto high = LiteralUtil::CreateR1( + {-6780561065411491180LL, 8616761059752331528LL, 3832151243857508051LL}); + auto expected = LiteralUtil::CreateR1( + {-6780561065411491190LL, 6780561065411491190LL, 3832151243857508051LL}); + TestTernaryOp(HloOpcode::kClamp, std::move(expected), std::move(low), + std::move(value), std::move(high)); +} + TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) { auto lhs = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); auto rhs = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 8def61dc63db2c55f926a04fe097988af4417c1a..d516a6258c80bda168ef4c6fd976e60946eb8b5b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #include +#include #include "absl/algorithm/container.h" #include "absl/base/casts.h" @@ -893,9 +894,29 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleShiftRightLogical(shrl); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + // Special case for integral type due to MSVC's std::isnan being unable to + // handle integral type. + template ::value && + std::is_integral::value>::type* = + nullptr> + Status HandleClamp(HloInstruction* clamp) { + std::function + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { + return static_cast( + std::min(high, std::max(value, low))); + }; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); + return Status::OK(); + } + + template ::value && + !std::is_integral::value>::type* = + nullptr> Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { @@ -903,7 +924,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast(NAN); } return static_cast( - std::fmin(high, std::fmax(value, low))); + std::min(high, std::max(value, low))); }; TF_ASSIGN_OR_RETURN( parent_->evaluated_[clamp], @@ -2670,12 +2691,25 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& high = parent_->GetEvaluatedLiteralFor(random->operand(1)); - std::uniform_real_distribution generator( - low.Get({}), high.Get({})); - + // std::uniform_real_distribution(a, b) can sometimes return a value + // equal to b. Unclear if this is a spec bug or an implementation bug + // or WAI [0] [1] [2]. Anyway for our purposes we want a half-open + // interval, so we have to re-sample if we get `b` out. + // + // [0] https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176 + // [1] https://bugs.llvm.org/show_bug.cgi?id=18767 + // [2] http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524 + auto low_val = low.Get({}); + auto high_val = high.Get({}); + std::uniform_real_distribution generator(low_val, high_val); TF_RETURN_IF_ERROR( result.Populate([&](absl::Span /*indexes*/) { - return generator(parent_->engine_); + while (true) { + NativeT v = generator(parent_->engine_); + if (v != high_val) { + return v; + } + } })); break; } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index e6f446c92687d0b27fcf1cdc4f38919e64c1035b..49300b3ffe2f755d103af7877ab3fee5298eeb3e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -38,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -536,7 +535,12 @@ stylesheet=< } } - return StrFormat(fmt, graph_label, StrJoin(edge_css_rules, "\n")); + // Browsers require that we URI-encode the contents of our data URI. (It + // seems this was a relatively recent change?) In practice, this means that we + // need to escape '#'. + return StrFormat( + fmt, graph_label, + absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}})); } string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); } @@ -1451,9 +1455,6 @@ string SaveGraph(const string& graph, case GraphRendererInterface::DOT_GRAPH: file_extension = ".dot"; break; - case GraphRendererInterface::TF_GRAPHDEF: - file_extension = ".pbtxt"; - break; } string path = JoinPath(dest_path, StrCat("hlo_graph_", output_num++, ".")); auto status = Status::OK(); @@ -1491,25 +1492,27 @@ string ExportGraph(const string& graph, } // namespace +string HloComputationToDotGraph(const HloComputation& computation, + const DotGraphOptions& options) { + DebugOptions default_debug_options; + return HloDotDumper(&computation, options.label, + options.debug_options ? *options.debug_options + : default_debug_options, + options.show_backend_config, options.profile, + NodeFilter()) + .Dump(); +} + string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; - string graph; - if (debug_options.xla_hlo_dump_as_graphdef()) { - HloTfGraphBuilder builder(debug_options); - TF_CHECK_OK(builder.AddComputation(computation)); - CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), - &graph)); - graph_kind = GraphRendererInterface::TF_GRAPHDEF; - } else { - graph = - HloDotDumper(&computation, label, debug_options, show_backend_config, - hlo_execution_profile, NodeFilter()) - .Dump(); - graph_kind = GraphRendererInterface::DOT_GRAPH; - } + string graph = + HloDotDumper(&computation, label, debug_options, show_backend_config, + hlo_execution_profile, NodeFilter()) + .Dump(); + graph_kind = GraphRendererInterface::DOT_GRAPH; string graph_url = ExportGraph(graph, graph_kind, debug_options); LOG(INFO) << "computation " << computation.name() << " [" << label diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index b5444a32b18bfe75d048009a49f170930befd12d..563cea42371d370b4c9ea739418692fd74dca799 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -26,13 +26,23 @@ limitations under the License. namespace xla { namespace hlo_graph_dumper { +// Converts a HLO module to a DOT (graphviz) graph. Returns the dot graph as +// a string. +struct DotGraphOptions { + absl::string_view label; + const DebugOptions* debug_options = nullptr; + const HloExecutionProfile* profile = nullptr; + bool show_backend_config = false; +}; +string HloComputationToDotGraph(const HloComputation& computation, + const DotGraphOptions& options); + // Abstract interface for classes that render HLO graphs (e.g. DOT graph, -// tensorflow GraphDef). +// tensorflow GraphDef) to files or services. class GraphRendererInterface { public: enum GraphKind { DOT_GRAPH, - TF_GRAPHDEF, }; virtual ~GraphRendererInterface() = default; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index aa1f3a2421f52c45145731a0203bd46f3ea574cf..33c2270eb0a847d088776a2d9d67e341a69dbae2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -214,7 +214,7 @@ StatusOr> HloInstruction::CreateFromProto( << proto.called_computation_ids_size(); auto sort_operands = all_operands(); instruction = CreateSort(shape, proto.dimensions(0), all_operands(), - computations(0)); + computations(0), proto.is_stable()); break; } case HloOpcode::kTranspose: @@ -304,6 +304,10 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kParameter: instruction = CreateParameter(proto.parameter_number(), shape, proto.name()); + if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) { + instruction->set_parameter_replicated_at_leaf_buffers( + proto.parameter_replication().replicated_at_leaf_buffers()); + } break; case HloOpcode::kGetTupleElement: TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -1170,9 +1174,10 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateSort( const Shape& shape, int64 dimension, - absl::Span operands, HloComputation* compare) { + absl::Span operands, HloComputation* compare, + bool is_stable) { return absl::make_unique(shape, dimension, operands, - compare); + compare, is_stable); } /* static */ std::unique_ptr HloInstruction::CreateFusion( @@ -3321,6 +3326,19 @@ int64 HloInstruction::parameter_number() const { return Cast(this)->parameter_number(); } +void HloInstruction::set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers) { + return Cast(this) + ->set_parameter_replicated_at_leaf_buffers( + parameter_replicated_at_leaf_buffers); +} + +const absl::optional>& +HloInstruction::parameter_replicated_at_leaf_buffers() const { + return Cast(this) + ->parameter_replicated_at_leaf_buffers(); +} + int64 HloInstruction::tuple_index() const { return Cast(this)->tuple_index(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index f3a50c59362a9975c10f0d21356a387422fc10d1..8ac1636d7159c7cb478856737d93387be49f1ba1 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -47,6 +47,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -384,6 +385,14 @@ class HloInstruction { // Creates a random number generation instruction that fills a shape with // random numbers from a given distribution. + // + // The parameters to the instruction are interpreted as follows: + // + // - If `distribution` is RNG_UNIFORM, generates a number in range + // [param0, param1). + // + // - If `distribution` is RNG_NORMAL, generates a normally-distributed value + // with mean `param0` and standard deviation `param1`. static std::unique_ptr CreateRng( const Shape& shape, RandomDistribution distribution, absl::Span parameters); @@ -493,7 +502,7 @@ class HloInstruction { // Data is sent/received according to the (source_replica_id, // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a // target_replica_id in any pair, the output on that replica is a tensor - // conssits of 0(s) in `shape`. + // consists of 0(s) in `shape`. static std::unique_ptr CreateCollectivePermute( const Shape& shape, HloInstruction* operand, const std::vector>& source_target_pairs); @@ -678,10 +687,11 @@ class HloInstruction { // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters, // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at // specific index positions which should be compared, and should return a - // PRED. + // PRED. 'is_stable' specifies whether stable sorting is required. static std::unique_ptr CreateSort( const Shape& shape, int64 dimension, - absl::Span operands, HloComputation* compare); + absl::Span operands, HloComputation* compare, + bool is_stable); // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. For @@ -1459,6 +1469,15 @@ class HloInstruction { // Delegates to HloParameterInstruction::parameter_number. int64 parameter_number() const; + // Delegates to + // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers. + void set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers); + + // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers. + const absl::optional>& + parameter_replicated_at_leaf_buffers() const; + // Delegates to HloGetTupleElementInstruction::tuple_index. int64 tuple_index() const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 92a74187c50db011b3c50ed6661354b5d33aef9e..905a6fe08b4430ad862edf0886a57c9f7e9f7977 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -218,11 +218,14 @@ HloInstructionProto HloTriangularSolveInstruction::ToProto() const { std::vector HloTriangularSolveInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("left_side=", triangular_solve_options_.left_side()), - StrCat("lower=", triangular_solve_options_.lower()), - StrCat("unit_diagonal=", triangular_solve_options_.unit_diagonal()), - StrCat("transpose_a=", TriangularSolveOptions_Transpose_Name( - triangular_solve_options_.transpose_a()))}; + return { + StrCat("left_side=", + triangular_solve_options_.left_side() ? "true" : "false"), + StrCat("lower=", triangular_solve_options_.lower() ? "true" : "false"), + StrCat("unit_diagonal=", + triangular_solve_options_.unit_diagonal() ? "true" : "false"), + StrCat("transpose_a=", TriangularSolveOptions_Transpose_Name( + triangular_solve_options_.transpose_a()))}; } bool HloTriangularSolveInstruction::IdenticalSlowPath( @@ -659,8 +662,11 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( HloSortInstruction::HloSortInstruction( const Shape& shape, int64 dimension, - absl::Span operands, HloComputation* compare) - : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { + absl::Span operands, HloComputation* compare, + bool is_stable) + : HloInstruction(HloOpcode::kSort, shape), + dimensions_({dimension}), + is_stable_(is_stable) { for (auto* value : operands) { AppendOperand(value); } @@ -672,12 +678,18 @@ HloInstructionProto HloSortInstruction::ToProto() const { for (int64 dimension : dimensions_) { proto.add_dimensions(dimension); } + proto.set_is_stable(is_stable()); return proto; } std::vector HloSortInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; + std::vector attrs; + attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}")); + if (is_stable()) { + attrs.push_back("is_stable=true"); + } + return attrs; } bool HloSortInstruction::IdenticalSlowPath( @@ -688,14 +700,17 @@ bool HloSortInstruction::IdenticalSlowPath( if (dimensions() != casted_other.dimensions()) { return false; } + if (is_stable() != casted_other.is_stable()) { + return false; + } return eq_computations(to_apply(), other.to_apply()); } std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return absl::make_unique(shape, dimensions(0), - new_operands, to_apply()); + return absl::make_unique( + shape, dimensions(0), new_operands, to_apply(), is_stable()); } HloTransposeInstruction::HloTransposeInstruction( @@ -1523,9 +1538,30 @@ HloParameterInstruction::HloParameterInstruction(int64 parameter_number, HloInstructionProto HloParameterInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_parameter_number(parameter_number_); + if (parameter_replicated_at_leaf_buffers_) { + for (bool replicated : *parameter_replicated_at_leaf_buffers_) { + proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers( + replicated); + } + } return proto; } +std::vector HloParameterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result; + if (!parameter_replicated_at_leaf_buffers_) { + return result; + } + std::vector buffers_replicated_strs; + for (bool replicated : *parameter_replicated_at_leaf_buffers_) { + buffers_replicated_strs.push_back(replicated ? "true" : "false"); + } + result.push_back(StrCat("parameter_replication={", + StrJoin(buffers_replicated_strs, ","), "}")); + return result; +} + string HloParameterInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { @@ -2013,6 +2049,17 @@ bool HloCustomCallInstruction::IdenticalSlowPath( if (batch_group_count_ != casted_other.batch_group_count_) { return false; } + if (layout_constrained() != casted_other.layout_constrained()) { + return false; + } + if (layout_constrained()) { + for (int64 i = 0; i < operand_shapes_with_layout_.size(); ++i) { + if (!ShapeUtil::Equal(operand_shapes_with_layout_[i], + casted_other.operand_shapes_with_layout_[i])) { + return false; + } + } + } return custom_call_target_ == casted_other.custom_call_target_ && opaque_ == casted_other.opaque_; } @@ -2023,6 +2070,10 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { auto cloned = absl::make_unique( shape, new_operands, custom_call_target(), opaque()); + if (layout_constrained()) { + cloned->layout_constrained_ = true; + cloned->operand_shapes_with_layout_ = operand_shapes_with_layout(); + } if (window_ != nullptr) { cloned->set_window(*window_); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index a0f2b46ba41cc6a60e28050d3cdd5e4e4583a875..4d23cb671f24623f56faa9b69015cef21752a799 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -447,7 +447,7 @@ class HloSortInstruction : public HloInstruction { public: explicit HloSortInstruction(const Shape& shape, int64 dimension, absl::Span operands, - HloComputation* compare); + HloComputation* compare, bool is_stable); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -460,6 +460,7 @@ class HloSortInstruction : public HloInstruction { HloInstruction* mutable_keys() { return mutable_operand(0); } // Returns the number of value operands. int64 values_count() const { return operand_count() - 1; } + bool is_stable() const { return is_stable_; } private: std::vector ExtraAttributesToStringImpl( @@ -474,6 +475,7 @@ class HloSortInstruction : public HloInstruction { HloCloneContext* context) const override; std::vector dimensions_; + bool is_stable_; }; class HloTransposeInstruction : public HloInstruction { @@ -815,10 +817,28 @@ class HloParameterInstruction : public HloInstruction { explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, const string& name); int64 parameter_number() const { return parameter_number_; } + + // Sets and gets the whether all replicas will receive the same parameter data + // for each leaf buffer in data parallelism. + void set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers) { + CHECK_EQ(ShapeUtil::GetLeafCount(shape()), + parameter_replicated_at_leaf_buffers.size()); + parameter_replicated_at_leaf_buffers_.emplace( + parameter_replicated_at_leaf_buffers.begin(), + parameter_replicated_at_leaf_buffers.end()); + } + const absl::optional>& + parameter_replicated_at_leaf_buffers() const { + return parameter_replicated_at_leaf_buffers_; + } + // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; bool IdenticalSlowPath( const HloInstruction& other, const std::function& @@ -832,6 +852,10 @@ class HloParameterInstruction : public HloInstruction { HloCloneContext* context) const override; int64 parameter_number_ = 0; + + // Specifies whether each buffer has the same parameter value on all replicas + // in data parallelism. + absl::optional> parameter_replicated_at_leaf_buffers_; }; class HloGetTupleElementInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 20dbed07c546b3ec465e3b57c73a43c6c8f98efc..f448571082e52e4b81db1c68d1e1470935386139 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -82,6 +82,7 @@ class HloParser { // Stand alone parsing utils for various aggregate data types. StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); + StatusOr> ParseParameterReplicationOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); StatusOr ParsePaddingConfigOnly(); @@ -183,6 +184,7 @@ class HloParser { kWindow, kConvolutionDimensionNumbers, kSharding, + kParameterReplication, kInstructionList, kSliceRanges, kPaddingConfig, @@ -247,6 +249,7 @@ class HloParser { bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + bool ParseParameterReplication(ParameterReplication* parameter_replication); // Parses the metadata behind a kDOmain instruction. bool ParseDomain(DomainData* domain); @@ -644,6 +647,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, std::unordered_map attrs; optional sharding; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; + optional parameter_replication; + attrs["parameter_replication"] = {/*required=*/false, + AttrTy::kParameterReplication, + ¶meter_replication}; optional> predecessors; attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, &predecessors}; @@ -895,6 +902,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; + optional is_stable = false; + attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable}; optional to_apply; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; @@ -902,8 +911,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, dimensions->size() != 1) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), operands, to_apply.value())); + instruction = builder->AddInstruction( + HloInstruction::CreateSort(shape, dimensions->at(0), operands, + to_apply.value(), is_stable.value())); break; } case HloOpcode::kTuple: { @@ -1675,6 +1685,18 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); } + if (parameter_replication) { + int leaf_count = ShapeUtil::GetLeafCount(instruction->shape()); + const auto& replicated = + parameter_replication->replicated_at_leaf_buffers(); + if (leaf_count != replicated.size()) { + return Error(lexer_.GetLoc(), + StrCat("parameter has ", leaf_count, + " leaf buffers, but parameter_replication has ", + replicated.size(), " elements.")); + } + instruction->set_parameter_replicated_at_leaf_buffers(replicated); + } if (predecessors) { for (auto* pre : *predecessors) { Status status = pre->AddControlDependencyTo(instruction); @@ -1834,6 +1856,32 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return true; } +// parameter_replication ::= +// '{' ('true' | 'false')* (',' ('true' | 'false'))* '}' +bool HloParser::ParseParameterReplication( + ParameterReplication* parameter_replication) { + if (!ParseToken(TokKind::kLbrace, + "expected '{' to start parameter_replication attribute")) { + return false; + } + + if (lexer_.GetKind() != TokKind::kRbrace) { + do { + if (lexer_.GetKind() == TokKind::kw_true) { + parameter_replication->add_replicated_at_leaf_buffers(true); + } else if (lexer_.GetKind() == TokKind::kw_false) { + parameter_replication->add_replicated_at_leaf_buffers(false); + } else { + return false; + } + lexer_.Lex(); + } while (EatIfPresent(TokKind::kComma)); + } + + return ParseToken(TokKind::kRbrace, + "expected '}' to end parameter_replication attribute"); +} + // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' // 'exit=' exit_sharding '}' bool HloParser::ParseDomain(DomainData* domain) { @@ -2684,6 +2732,15 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(sharding); return true; } + case AttrTy::kParameterReplication: { + ParameterReplication parameter_replication; + if (!ParseParameterReplication(¶meter_replication)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(parameter_replication); + return true; + } case AttrTy::kInstructionList: { std::vector result; if (!ParseInstructionNames(&result)) { @@ -3785,6 +3842,21 @@ StatusOr HloParser::ParseShardingOnly() { return HloSharding::FromProto(op_sharding); } +StatusOr> HloParser::ParseParameterReplicationOnly() { + lexer_.Lex(); + ParameterReplication parameter_replication; + if (!ParseParameterReplication(¶meter_replication)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after parameter replication"); + } + return std::vector( + parameter_replication.replicated_at_leaf_buffers().begin(), + parameter_replication.replicated_at_leaf_buffers().end()); +} + StatusOr HloParser::ParseWindowOnly() { lexer_.Lex(); Window window; @@ -3900,6 +3972,11 @@ StatusOr ParseSharding(absl::string_view str) { return parser.ParseShardingOnly(); } +StatusOr> ParseParameterReplication(absl::string_view str) { + HloParser parser(str); + return parser.ParseParameterReplicationOnly(); +} + StatusOr ParseWindow(absl::string_view str) { HloParser parser(str); return parser.ParseWindowOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 450a54c54c156c2ae27475d145a8e83dc841b431..a96260b4d75e515a4cb23d315444142cae1b9587 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -44,11 +44,16 @@ Status ParseHloString(absl::string_view str, HloModule* module); // creates a HloModule with default config. StatusOr> ParseHloString(absl::string_view str); -// ParseHloString sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string, -// e.g., "{replicated}". +// Parses sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g., +// "{replicated}". StatusOr ParseSharding(absl::string_view str); +// Parses parameter replication from str. str is supposed to contain the body of +// the parameter replication, i.e. just the rhs of the +// "parameter_replication={...}" attribute string, e.g., "{true, false}". +StatusOr> ParseParameterReplication(absl::string_view str); + // Parses the result of window_util::ToString(const Window&). StatusOr ParseWindow(absl::string_view str); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 203a7dba22110063b54467bd8e550fa8f23c68d1..8e3f1e44b9562334130aa565ed447a78899fad53 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -63,6 +63,19 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) } +)" +}, +// parameter replication +{ +"ParamReplication", +R"(HloModule param_replication_module + +ENTRY %param_replication (a: f32[], b: (f32[2,4], (f32[2,4]))) -> (f32[], (f32[2,4], (f32[2,4]))) { + %a = f32[] parameter(0), parameter_replication={true} + %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true} + ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b) +} + )" }, // pred constant @@ -1145,6 +1158,24 @@ ENTRY Sort { ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare } +)" +}, +// Sort (Key) is_stable=true +{ +"SortKeyStable", +R"(HloModule sort + +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + +ENTRY Sort { + x = f32[1024]{0} parameter(0) + ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare +} + )" }, // Conditional @@ -2692,5 +2723,16 @@ TEST_F(HloParserTest, NegativeParameterNumber) { ::testing::HasSubstr("parameter number must be >= 0")); } +TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) { + const string hlo_string = + "par0 = (f32[3,5], f32[]) parameter(0), " + "parameter_replication={true,false,true}"; + auto result = ParseHloString(hlo_string); + ASSERT_FALSE(result.status().ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("parameter has 2 leaf buffers, but " + "parameter_replication has 3 elements")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 0fced7f15bdaf1dbe349e3b0fc6ada68393c6512..b7f507b1184dbe021effc1102a68040286480ed2 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -77,28 +77,51 @@ std::unique_ptr HloReachabilityMap::Build( const HloComputation* computation) { const auto& all = computation->MakeInstructionPostOrder(); auto result = absl::make_unique(all); - auto channel_dependency_map = computation->ComputeChannelDependencies(); + auto channel_group = computation->ComputeChannelDependencies(); - std::vector inputs; for (const HloInstruction* hlo : all) { - inputs.assign(hlo->operands().begin(), hlo->operands().end()); - inputs.insert(inputs.end(), hlo->control_predecessors().begin(), - hlo->control_predecessors().end()); + std::vector inputs; + const auto add_input = [&channel_group, &inputs](HloInstruction* input) { + inputs.push_back(input); + if (input->opcode() == HloOpcode::kAllReduce && input->all_reduce_id()) { + auto it = channel_group.find(*input->all_reduce_id()); + if (it != channel_group.end()) { + inputs.insert(inputs.end(), it->second.begin(), it->second.end()); + } + } + }; + + const auto add_dependencies = [&add_input](const HloInstruction* hlo) { + for (HloInstruction* operand : hlo->operands()) { + add_input(operand); + } + for (HloInstruction* predecessor : hlo->control_predecessors()) { + add_input(predecessor); + } + }; + + add_dependencies(hlo); switch (hlo->opcode()) { case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(hlo->channel_id()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); + auto it = channel_group.find(hlo->channel_id()); + if (it != channel_group.end()) { + for (HloInstruction* channel : it->second) { + if (channel->opcode() == HloOpcode::kSend) { + add_input(channel); + } + } } break; } case HloOpcode::kAllReduce: { auto all_reduce_id = hlo->all_reduce_id(); if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); + auto it = channel_group.find(all_reduce_id.value()); + if (it != channel_group.end()) { + for (HloInstruction* all_reduce : it->second) { + add_dependencies(all_reduce); + } } } break; diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 84399f17e5e0b0b1f29cded17b605571bcfa8843..5a5401e351384867016a3a9addfd43d57091848c 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -176,7 +176,7 @@ StatusOr HloRunner::Execute( TransferLiteralsToDevice(arguments)); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, ExecuteWithDeviceBuffers( - /*module=*/std::move(executable), + /*executable=*/executable.get(), /*arguments=*/argument_buffers, /*profile=*/profile)); return TransferLiteralFromDevice(result); @@ -235,7 +235,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile) { // Get service run options. @@ -254,7 +254,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile) { std::vector argument_pointers; diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index a6e6015d6a5e2ad6e85cf2411f1a740c0987d8b4..098989cd4c78fb5ad57cd6700fbf99c50064f225 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -60,7 +60,7 @@ class HloRunner { // The number of times the infeed literal should be fed to the HLO module. // For a clean exit, this should match the iterations-per-loop parameter // used when generating the HLO module proto (that is usually the main - // while bounary counter). A value higher then iterations-per-loop would + // while boundary counter). A value higher then iterations-per-loop would // lead to infeed threads feeding to a gone computation, while a lower // value would trigger a stuck ExecuteReplicated() call (the computation // will be trying to infeed data which will never come). @@ -144,13 +144,16 @@ class HloRunner { const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + // In the following two calls, "executable" is not a unique_ptr to allow + // reuse of the Executable. This call may update the profile information in + // *executable. StatusOr ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile = nullptr); StatusOr ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile = nullptr); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc deleted file mode 100644 index 6925dc37dbe9dc90e79d315cf41a3416e2084c81..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ /dev/null @@ -1,242 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -LIcensed under the Apache License, Version 2.0 (the "License"); -You may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" - -namespace xla { -namespace hlo_graph_dumper { -namespace { - -using absl::StrAppend; -using absl::StrCat; -using tensorflow::GraphDef; -using tensorflow::NodeDef; -using tensorflow::TensorShapeProto; - -string GetOpDefName(const HloInstruction* instruction) { - string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); - tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok - name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); - - if (instruction->opcode() == HloOpcode::kFusion) { - string fusion_name = ToString(instruction->fusion_kind()); - StrAppend(&name, absl::string_view(fusion_name).substr(1)); - } - return name; -} - -TensorShapeProto GetTensorShape(const HloInstruction* instruction) { - TensorShapeProto tensor_shape; - const Shape& shape = instruction->shape(); - for (auto dim : shape.dimensions()) { - tensor_shape.add_dim()->set_size(dim); - } - return tensor_shape; -} - -string GetDeviceName(int device) { return StrCat("/device/XLA:", device); } - -void CleanNodeName(string* name) { - name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); - const string chars_to_replace = "<>[]"; - auto pred = [&](char c) { - return absl::c_linear_search(chars_to_replace, c); - }; - std::replace_if(name->begin(), name->end(), pred, '_'); -} - -} // namespace - -HloTfGraphBuilder::HloTfGraphBuilder(const DebugOptions& debug_options) - : debug_options_(debug_options) {} - -Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { - VLOG(2) << "Adding computation " << computation.name(); - for (auto embedded : computation.MakeEmbeddedComputationsList()) { - for (auto* instruction : embedded->instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction)); - } - } - for (auto* instruction : computation.instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction)); - } - return Status::OK(); -} - -const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; } - -const string& HloTfGraphBuilder::GetNodeNameForInstruction( - const HloInstruction* instruction) { - if (ContainsKey(instruction_to_node_name_, instruction)) { - return instruction_to_node_name_[instruction]; - } - auto append = [](string* str, const string& other) { - if (str->empty()) { - *str = other; - } else if (!other.empty()) { - StrAppend(str, "/", other); - } - }; - string node_name; - if (debug_options_.xla_hlo_tfgraph_device_scopes()) { - auto device = instruction->sharding_unique_device(); - if (device) { - node_name = StrCat("dev", *device); - } - } - // If an instruction is fused, put it in the subgraph of the fusion; - // otherwise, put it in the computation subgraph. - const HloComputation* computation = instruction->parent(); - if (computation->IsFusionComputation()) { - append(&node_name, - GetNodeNameForInstruction(computation->FusionInstruction())); - } else { - append(&node_name, computation->name()); - if (!instruction->metadata().op_name().empty()) { - // Always make computations contain TF ops but not the other way around. - append(&node_name, instruction->metadata().op_name()); - } - } - string instruction_name = instruction->name(); - if (instruction->opcode() == HloOpcode::kParameter) { - StrAppend(&instruction_name, ".", instruction->parameter_number()); - } - append(&node_name, instruction_name); - CleanNodeName(&node_name); - auto ret = - instruction_to_node_name_.insert(std::make_pair(instruction, node_name)); - CHECK(ret.second); - return ret.first->second; -} - -void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, - NodeDef* node_def) const { - auto& attrs = *node_def->mutable_attr(); - - // Set the number of arguments for instructions that have variadic operands. - if (HloOpcodeIsVariadic(instruction->opcode())) { - tensorflow::AttrValue attr_value; - attr_value.set_i(instruction->operands().size()); - attrs["arg_num"] = attr_value; - } - - // Set the node type. - attrs["type"].set_s( - xla::PrimitiveType_Name(instruction->shape().element_type())); - - // Set the framework op (e.g. Tensorflow op) that generated this XLA op. - attrs["tf_op_type"].set_s(instruction->metadata().op_type()); - attrs["tf_op_name"].set_s(instruction->metadata().op_name()); - - // Set the shape of the output tensor. "_output_shapes" is a special attribute - // name used by Tensorboard for shapes of output tensors. - tensorflow::AttrValue shapes; - *shapes.mutable_list()->add_shape() = GetTensorShape(instruction); - attrs["_output_shapes"] = shapes; - - // Set the layout. - if (LayoutUtil::HasLayout(instruction->shape())) { - string layout_string; - if (instruction->shape().IsTuple()) { - // For tuples, emit the full shape because the layout of a tuple is not - // represented in a single Layout field. - layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); - } else if (instruction->shape().has_layout()) { - // For non-tuples, only emit the layout when the shape has a Layout. - // This extra check is required because LayoutUtil::HasLayout ignores - // token, opaque types etc. - layout_string = instruction->shape().layout().ToString(); - } - attrs["layout"].set_s(layout_string); - } - - // Set op-specific attributes. - switch (instruction->opcode()) { - case HloOpcode::kConcatenate: - case HloOpcode::kBroadcast: - case HloOpcode::kReduce: - case HloOpcode::kReverse: - case HloOpcode::kTranspose: - for (auto dim : instruction->dimensions()) { - attrs["dims"].mutable_list()->add_i(dim); - } - break; - case HloOpcode::kGetTupleElement: - attrs["index"].set_i(instruction->tuple_index()); - break; - case HloOpcode::kRng: - attrs["dist"].set_s( - RandomDistribution_Name(instruction->random_distribution())); - break; - case HloOpcode::kConstant: - if (ShapeUtil::IsScalar(instruction->shape())) { - attrs["value"].set_s(instruction->literal().GetAsString({})); - } - break; - case HloOpcode::kCustomCall: - attrs["custom_call_target"].set_s(instruction->custom_call_target()); - break; - case HloOpcode::kSend: - case HloOpcode::kRecv: - attrs["channel_id"].set_i(instruction->channel_id()); - break; - default: - break; - } -} - -Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { - if (!visited_instructions_.insert(instruction).second) { - // Skip instructions that have already been added. - return Status::OK(); - } - - NodeDef* node_def = graph_def_.add_node(); - node_def->set_name(GetNodeNameForInstruction(instruction)); - node_def->set_op(GetOpDefName(instruction)); - - auto device = instruction->sharding_unique_device(); - if (device) { - node_def->set_device(GetDeviceName(*device)); - } - SetNodeAttrs(instruction, node_def); - if (instruction->opcode() == HloOpcode::kFusion) { - for (auto* fused_instruction : instruction->fused_instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(fused_instruction)); - } - } - // Add all edges including control edges. - for (unsigned i = 0; i < instruction->operands().size(); ++i) { - *node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i)); - } - // Called computations are control dependencies. - for (const auto* called_computation : instruction->called_computations()) { - *node_def->add_input() = StrCat( - "^", GetNodeNameForInstruction(called_computation->root_instruction())); - } - return Status::OK(); -} - -} // namespace hlo_graph_dumper -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h deleted file mode 100644 index c4876b852e32d34693202f4023aa20ad2b301ffd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" - -namespace xla { -namespace hlo_graph_dumper { - -// This constructs a tensorflow graph for HLO computations. -class HloTfGraphBuilder { - public: - HloTfGraphBuilder(const DebugOptions& debug_options = DebugOptions()); - - // Adds a computation to the graph. - Status AddComputation(const HloComputation& computation); - - const tensorflow::GraphDef& GetGraphDef() const; - - private: - // Gets the node name of an instruction. The node name is hierarchical. For - // example, if an instruction is fused, it will be put in a subgraph of the - // fusion instruction. - const string& GetNodeNameForInstruction(const HloInstruction* instruction); - - void SetNodeAttrs(const HloInstruction* instruction, - tensorflow::NodeDef* node_def) const; - - Status AddInstruction(const HloInstruction* instruction); - - DebugOptions debug_options_; - tensorflow::GraphDef graph_def_; - // This records instructions that have been visited. - std::unordered_set visited_instructions_; - // A cache that maps instruction to the node name. - std::unordered_map instruction_to_node_name_; -}; - -} // namespace hlo_graph_dumper -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc deleted file mode 100644 index 498abcfe04d963575fb9200443efb7d911a6293e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ /dev/null @@ -1,201 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace xla { -namespace hlo_graph_dumper { -namespace { - -using ::tensorflow::GraphDef; - -class HloTfGraphBuilderTest : public HloTestBase { - protected: - HloTfGraphBuilderTest() {} - HloTfGraphBuilder generator_; - - // Create a computation which takes a scalar and returns its negation. - std::unique_ptr CreateNegateComputation() { - auto builder = HloComputation::Builder("Negate"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); - return builder.Build(); - } - - // Creates a computation which calls map with the given computation. - std::unique_ptr CreateMapComputation( - HloComputation *map_computation) { - auto builder = HloComputation::Builder("Map"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map_computation)); - return builder.Build(); - } - Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {}); -}; - -static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node, - const string &attr_name) { - auto attr = node.attr().find(attr_name); - CHECK(attr != node.attr().end()); - return attr->second; -} - -TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { - auto builder = HloComputation::Builder("Concatenate"); - Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, shape, "param1")); - builder.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {param_1, param_2}, 1)); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - const auto &node = graph_def.node(2); - EXPECT_EQ(node.name(), "Concatenate/concatenate"); - - // Check dimensions. - auto dims_value = GetNodeAttr(node, "dims"); - EXPECT_EQ(dims_value.list().i_size(), 1); - EXPECT_EQ(dims_value.list().i(0), 1); - - // Check shapes. - auto shape_value = GetNodeAttr(node, "_output_shapes"); - EXPECT_EQ(shape_value.list().shape_size(), 1); - EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2); - EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2); - EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4); -} - -TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { - auto builder = HloComputation::Builder("Const"); - HloInstruction *instruction = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); - OpMetadata metadata; - metadata.set_op_name("x"); - metadata.set_op_type("y"); - instruction->set_metadata(metadata); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 1); - const auto &node = graph_def.node(0); - EXPECT_EQ(GetNodeAttr(node, "value").s(), "123"); - EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32"); - EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x"); - EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y"); -} - -TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) { - auto negate_computation = CreateNegateComputation(); - TF_CHECK_OK(generator_.AddComputation(*negate_computation)); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 2); - EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0"); - EXPECT_EQ(graph_def.node(0).op(), "HloParameter"); - EXPECT_EQ(graph_def.node(1).name(), "Negate/negate"); - EXPECT_EQ(graph_def.node(1).op(), "HloNegate"); - EXPECT_EQ(graph_def.node(1).input_size(), 1); - EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0"); -} - -TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) { - auto builder = HloComputation::Builder("GE"); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32_, "param1")); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); - EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); - EXPECT_EQ(graph_def.node(2).input_size(), 2); - EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to"); - EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); -} - -TEST_F(HloTfGraphBuilderTest, IncorparateTfOpsStructure) { - auto builder = HloComputation::Builder("GE"); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32_, "param1")); - auto ge = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); - OpMetadata metadata; - metadata.set_op_name("x/y"); - metadata.set_op_type("Y"); - ge->set_metadata(metadata); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); - EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); - EXPECT_EQ(graph_def.node(2).input_size(), 2); - EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to"); - EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); -} - -TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) { - // Create computations with a diamond-shaped callgraph. - auto negate_computation = CreateNegateComputation(); - auto map1_computation = CreateMapComputation(negate_computation.get()); - auto map2_computation = CreateMapComputation(negate_computation.get()); - - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto map1 = builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get())); - auto map2 = builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get())); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2)); - auto computation = builder.Build(); - TF_CHECK_OK(generator_.AddComputation(*computation)); - EXPECT_GT(generator_.GetGraphDef().node_size(), 0); -} - -TEST_F(HloTfGraphBuilderTest, TokenHasNoLayout) { - auto builder = HloComputation::Builder("Token"); - auto token = builder.AddInstruction(HloInstruction::CreateToken()); - OpMetadata metadata; - metadata.set_op_name("x"); - metadata.set_op_type("y"); - token->set_metadata(metadata); - TF_ASSERT_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - ASSERT_EQ(graph_def.node_size(), 1); - const auto &node = graph_def.node(0); - ASSERT_EQ(GetNodeAttr(node, "type").s(), "TOKEN"); - ASSERT_EQ(GetNodeAttr(node, "layout").s(), ""); - ASSERT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x"); - ASSERT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y"); -} - -} // namespace -} // namespace hlo_graph_dumper -} // namespace xla diff --git a/tensorflow/compiler/xla/service/op_expander_pass.cc b/tensorflow/compiler/xla/service/op_expander_pass.cc index 87f0886a9737807bdf6f00921b813d21b69a18cc..02c9d4b387b112be39c204d35fe4fa1013ed064c 100644 --- a/tensorflow/compiler/xla/service/op_expander_pass.cc +++ b/tensorflow/compiler/xla/service/op_expander_pass.cc @@ -36,6 +36,9 @@ StatusOr OpExpanderPass::Run(HloModule* module) { for (HloInstruction* inst : matching_instructions) { TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandInstruction(inst)); + if (expanded_root == nullptr) { + continue; + } TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); } diff --git a/tensorflow/compiler/xla/service/op_expander_pass.h b/tensorflow/compiler/xla/service/op_expander_pass.h index 794849d354bef6e2b0b79e6d03af4ed851dfbdb3..276e3d70b8ecd8742e0b277698765063198fe872 100644 --- a/tensorflow/compiler/xla/service/op_expander_pass.h +++ b/tensorflow/compiler/xla/service/op_expander_pass.h @@ -33,7 +33,9 @@ class OpExpanderPass : public HloModulePass { // Returns `true` if `instruction` should be expanded by this pass. virtual bool InstructionMatchesPattern(HloInstruction* instruction) = 0; - // Returns a replacement for `instruction`. + // Returns a replacement for `instruction`, or nullptr if no replacement is + // neeeded (e.g. only the to_apply subcomputation of the instruction was + // modified). virtual StatusOr ExpandInstruction( HloInstruction* instruction) = 0; }; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index a570ee346d2f50e6eb2a592452bdec423556a916..3f4456c1bbf0f620609459256424b9cb30a04e13 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -836,7 +836,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(larger_shape)); } if (small_is_dynamic != large_is_dynamic) { - if ((small_dimension_size == 1 && !small_is_dynamic) || + if (small_dimension_size == large_dimension_size || + (small_dimension_size == 1 && !small_is_dynamic) || (large_dimension_size == 1 && !large_is_dynamic)) { // Do nothing. It's OK when the size-1 dimension is not static. } else { diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.cc b/tensorflow/compiler/xla/service/stable_sort_expander.cc new file mode 100644 index 0000000000000000000000000000000000000000..1aa7e5fe7c0d57ee3303480e4727c456727f64c8 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander.cc @@ -0,0 +1,204 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Looks for a iota operand that can be used as tie breaker in the computation. +// If no matching iota operand is found, a iota operand is added to Sort. The +// comparison computation is adjusted to break ties using the values from the +// iota operand. +StatusOr StableSortExpander::ExpandInstruction( + HloInstruction* instruction) { + auto* sort = Cast(instruction); + HloComputation* computation = sort->parent(); + + HloInstruction* expanded_sort = nullptr; + absl::flat_hash_set used_indices; + int64 iota_index = -1; + for (const HloInstruction* operand : sort->operands()) { + // We can only use the iota operand if it has an iota dimension which is the + // same as the dimension to sort. Also it should have an integral type that + // is large enough for the number of elements in the sort dimension. For + // now, we only allow S32, because we expect to find a S32 iota operand for + // all Sort ops which are created by TopK. + // TODO(b/122298745): Also support other types. + if (operand->opcode() == HloOpcode::kIota && + Cast(operand)->iota_dimension() == + sort->sort_dimension() && + operand->shape().element_type() == S32) { + iota_index = sort->operand_index(operand); + break; + } + } + + // If there is currently no iota operand which we could use for making the + // sort stable, we will have to add a new such operand. + if (iota_index == -1) { + Shape iota_shape = sort->operand(0)->shape(); + // We might need to use S64 if the number of elements in the sort dimension + // is bigger than 2^31 - 1. + // TODO(b/122298745): Handle Sort ops where S32 is too small for the number + // of elements in the sort dimension. + if (iota_shape.dimensions(sort->sort_dimension()) > + std::numeric_limits::max()) { + return Unimplemented( + "Stable sorting of more than 2^31-1 elements is not implemented"); + } + iota_shape.set_element_type(S32); + auto iota = computation->AddInstruction( + HloInstruction::CreateIota(iota_shape, sort->sort_dimension())); + + // Create a new comparator. + auto comparator = sort->to_apply(); + absl::flat_hash_map> + replacements; + std::vector> extra_parameters; + std::vector extra_parameter_ptrs; + Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); + extra_parameters.push_back(HloInstruction::CreateParameter( + sort->operand_count() * 2, scalar_shape, + absl::StrCat("p.", sort->operand_count(), ".lhs"))); + extra_parameter_ptrs.push_back(extra_parameters.back().get()); + extra_parameters.push_back(HloInstruction::CreateParameter( + sort->operand_count() * 2 + 1, scalar_shape, + absl::StrCat("p.", sort->operand_count(), ".rhs"))); + extra_parameter_ptrs.push_back(extra_parameters.back().get()); + sort->set_to_apply(sort->GetModule()->AddEmbeddedComputation( + comparator->CloneWithReplacements(std::move(replacements), + extra_parameter_ptrs))); + + // Replace the original sort op. + std::vector new_operands(sort->operands().begin(), + sort->operands().end()); + new_operands.push_back(iota); + std::vector new_shapes = sort->operand_count() == 1 + ? std::vector{sort->shape()} + : sort->shape().tuple_shapes(); + new_shapes.push_back(iota_shape); + Shape new_sort_shape = ShapeUtil::MakeTupleShape(new_shapes); + HloInstruction* new_sort = computation->AddInstruction( + sort->CloneWithNewOperands(new_sort_shape, new_operands)); + + // Add a "wrapper" around the new sort op to make sure we have the same + // shape as before. For the rank 1 case, we only need a GetTupleElement, + // otherwise we create a Tuple consisting of GetTupleElements of the new + // sort. + std::vector tuple_elements; + tuple_elements.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + sort->operand(i)->shape(), new_sort, i))); + } + expanded_sort = tuple_elements[0]; + if (tuple_elements.size() > 1) { + expanded_sort = computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); + } + sort = Cast(new_sort); + iota_index = sort->operand_count() - 1; + } + + // Modify the computation to break ties using the iota operand. + auto comparator = sort->to_apply(); + std::vector instructions_postorder = + comparator->MakeInstructionPostOrder(); + absl::flat_hash_map replacements; + // Look up instr in the replacements map, and return either the replacement, + // or instr, if the replacement isn't present. + auto replace = [&](HloInstruction* instr) { + auto it = replacements.find(instr); + if (it == replacements.end()) { + return instr; + } + return it->second; + }; + HloInstruction* old_root = comparator->root_instruction(); + // The comparison computation gets 2 * n parameters (n being the number of + // operands of Sort), where parameters 2 * i and 2 * i + 1 correspond to two + // different scalars of operand i of Sort which are to be compared. The + // comparison computation should induce a strict weak order, so if + // to_apply(p1.lhs, p1.rhs, ..., pn.lhs, pn.rhs) is equal to + // to_apply(p1.rhs, p1.lhs, ..., pn.rhs, pn.lhs), we can conclude that the + // values to be compared are equivalent, and perform a tie-breaker comparison. + // + // We clone each instruction with at least one operand, but use as new + // operands of the instruction the replacements of the original operands. + // Parameter 2 * i is replaced by parameter 2 * i + 1 and vice versa. This + // should make sure that the cloned root instruction gives the result of the + // comparison computation when being called with each scalar pair reversed. + // parameters corresponding to the iota operand. + for (int64 i = 0; i < comparator->num_parameters(); ++i) { + replacements[comparator->parameter_instruction(i)] = + comparator->parameter_instruction(i ^ 1); + } + HloInstruction* cloned_root = nullptr; + for (HloInstruction* inst : instructions_postorder) { + if (inst->operand_count() == 0) { + continue; + } + std::vector new_operands; + new_operands.reserve(inst->operand_count()); + for (HloInstruction* operand : inst->operands()) { + new_operands.push_back(replace(operand)); + } + auto new_instruction = + inst->CloneWithNewOperands(inst->shape(), new_operands); + replacements[inst] = new_instruction.get(); + if (inst == old_root) { + cloned_root = new_instruction.get(); + } + comparator->AddInstruction(std::move(new_instruction)); + } + CHECK_NE(cloned_root, nullptr); + Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); + HloInstruction* same = + comparator->AddInstruction(HloInstruction::CreateBinary( + scalar_pred, HloOpcode::kEq, old_root, cloned_root)); + HloInstruction* tie_breaker = + comparator->AddInstruction(HloInstruction::CreateBinary( + scalar_pred, HloOpcode::kLt, + comparator->parameter_instruction(2 * iota_index), + comparator->parameter_instruction(2 * iota_index + 1))); + HloInstruction* new_root = + comparator->AddInstruction(HloInstruction::CreateTernary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker, + old_root)); + comparator->set_root_instruction(new_root); + + return expanded_sort; +} + +bool StableSortExpander::InstructionMatchesPattern( + HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSort && + Cast(instruction)->is_stable(); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.h b/tensorflow/compiler/xla/service/stable_sort_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..31b6fd92d25370218017c58072f1aa5e64df00c3 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which expands Sort ops that have the is_stable field set to true +// into equivalent Sort ops which guarantee stable sorting without relying on +// the is_stable field. +class StableSortExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "stable-sort-expander"; } + + private: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/stable_sort_expander_test.cc b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a62d953e6e8fa2f3c1ecfd9e4a7900eee74f9dca --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc @@ -0,0 +1,358 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" + +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace m = match; + +using StableSortExpanderTest = HloTestBase; + +// Checks whether 'a' and 'b' are roots of equivalent computations, except that +// parameters 2 * i and 2 * i + 1 are switched. +bool IsSameComputationExceptParams(const HloInstruction* a, + const HloInstruction* b) { + if (a->opcode() != b->opcode() || a->operand_count() != b->operand_count()) { + return false; + } + if (a->opcode() == HloOpcode::kParameter) { + // Check that parameters were switched. + return a->parameter_number() == (b->parameter_number() ^ 1); + } + // If the operation has no operands, it should actually be the same. + if (a->operand_count() == 0) { + return a == b; + } + // Otherwise recursively compare all operands. + for (int64 i = 0; i < a->operand_count(); ++i) { + if (!IsSameComputationExceptParams(a->operand(i), b->operand(i))) { + return false; + } + } + return true; +} + +// Check that the comparison computation has been modified to add a tie breaker +// using 'iota_parameter'. +void CheckComputationHasTieBreaker(const HloInstruction* root, + int64 iota_parameter) { + // With the tie breaker, the root instruction should be + // Select(Eq(Comp(), CompReverse()), Lt(), Comp()) + // with Comp() being the original comparison function, and CompReverse() being + // the copied comparison function where the parameters are reversed. Lt() is + // the tie breaker comparison using the Iota operand. + ASSERT_EQ(root->opcode(), HloOpcode::kSelect); + ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kEq); + + // Check that the tie breaker instruction is correct. + EXPECT_THAT(root->operand(1), + GmockMatch(m::Lt(m::Parameter(iota_parameter * 2), + m::Parameter(iota_parameter * 2 + 1)))); + EXPECT_EQ(root->operand(2), root->operand(0)->operand(0)); + + // Check that Comp() and CompReverse() are equivalent except that + // CompReverse() has reversed parameters. + EXPECT_TRUE(IsSameComputationExceptParams(root->operand(0)->operand(0), + root->operand(0)->operand(1))); +} + +TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, + StabilizeSortReuseIotaOperandComplicatedComparison) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + max = u32[] constant(2147483647) + zero = s32[] constant(0) + lhs.signed = s32[] bitcast-convert(p.0.lhs) + lhs.unsigned = u32[] bitcast-convert(p.0.lhs) + lhs.flipped = u32[] subtract(max, lhs.unsigned) + lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped) + lhs.is_negative = pred[] less-than(lhs.flipped.signed, zero) + lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed) + rhs.signed = s32[] bitcast-convert(p.0.rhs) + rhs.unsigned = u32[] bitcast-convert(p.0.rhs) + rhs.flipped = u32[] subtract(max, rhs.unsigned) + rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped) + rhs.is_negative = pred[] less-than(rhs.flipped.signed, zero) + rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed) + ROOT lt = pred[] less-than(lhs.converted, rhs.converted) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + ROOT sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, GmockMatch(m::Tuple( + m::GetTupleElement( + m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 0), + m::GetTupleElement( + m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 1)))); + CheckComputationHasTieBreaker( + root->operand(0)->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, HonorIsStableFlag) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=false + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_FALSE(stabilizer.Run(module.get()).ValueOrDie()); +} + +TEST_F(StableSortExpanderTest, + StabilizeSortDontReuseIotaOperandWrongDimension) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=0 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + // Simplify away the "wrapper" tuple around the new sort. + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions( + [](const Shape&, const Shape&) { return false; })); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = f32[] parameter(2) + p.1.rhs = f32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = f32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + // Simplify away the "wrapper" tuple around the new sort. + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions( + [](const Shape&, const Shape&) { return false; })); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, StabilizeSortR1) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + mask = s32[] constant(65535) + lhs = s32[] and(p.0.lhs, mask) + rhs = s32[] and(p.0.rhs, mask) + ROOT lt = pred[] less-than(lhs, rhs) + } + + ENTRY sort_computation { + keys = s32[64,8732]{1,0} parameter(0) + ROOT sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare, + is_stable=true + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + mask = s32[] constant(65535) + lhs = s32[] and(p.0.lhs, mask) + rhs = s32[] and(p.0.rhs, mask) + ROOT lt = pred[] less-than(lhs, rhs) + } + + ENTRY sort_computation { + keys = s32[64,8732]{1,0} parameter(0) + sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare, + is_stable=true + ROOT neg = s32[64,8732]{1,0} negate(sort) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Negate(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0)))); + CheckComputationHasTieBreaker( + root->operand(0)->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/1); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 551602613927671c9c37a4e8685df76d6a4ca9cf..6f61fc44166298e86a88dfc4f0ce8526d65ffd02 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1072,7 +1072,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); TF_ASSERT_OK_AND_ASSIGN( - auto* sort, MakeSortHlo(keys_shape, {keys}, 0, &builder, module_.get())); + auto* sort, MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, + &builder, module_.get())); computation_ = module_->AddEntryComputation(builder.Build()); RunAnalysis(); @@ -1094,7 +1095,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { TF_ASSERT_OK_AND_ASSIGN( auto* sort, MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), - {keys, values}, 0, &builder, module_.get())); + {keys, values}, 0, /*is_stable=*/false, &builder, + module_.get())); computation_ = module_->AddEntryComputation(builder.Build()); RunAnalysis(); diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index 93d630b8f736f6c41d4014ef6415e80eac5a65ec..94854047e530babe2234381a615aeb805f0d5933 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -147,7 +147,14 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { return false; } if (LayoutUtil::IsDenseArray(lhs)) { - if (lhs.layout() != rhs.layout()) { + Layout::Equal equal; + if (ignore_tiles_in_layout_) { + equal.IgnoreTiles(); + } + if (ignore_element_size_in_layout_) { + equal.IgnoreElementSize(); + } + if (!equal(lhs.layout(), rhs.layout())) { VLOG(3) << "CompareShapes: lhs layout != rhs layout"; return false; } diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 1d594904e0b9e6f1779674e75b41b7a597788bac..78cea83c6d71e5965f10cd3a917ffccabd630462 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -146,10 +146,10 @@ class Shape { // // Examples: // - // - Comparing two shapes ignoring they layout difference: + // - Comparing two shapes ignoring their layout difference: // Equal().IgnoreLayout()(shape1, shape2); // - // - Comparing two shapes ignoring they layout and element type difference: + // - Comparing two shapes ignoring their layout and element type difference: // Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2); class Equal { public: @@ -161,6 +161,14 @@ class Shape { ignore_layout_ = true; return *this; } + Equal& IgnoreTilesInLayout() { + ignore_tiles_in_layout_ = true; + return *this; + } + Equal& IgnoreElementSizeInLayout() { + ignore_element_size_in_layout_ = true; + return *this; + } Equal& IgnoreElementType() { ignore_element_type_ = true; return *this; @@ -174,8 +182,10 @@ class Shape { return *this; } - public: + private: bool ignore_layout_ = false; + bool ignore_tiles_in_layout_ = false; + bool ignore_element_size_in_layout_ = false; bool ignore_element_type_ = false; bool ignore_fp_precision_ = false; bool ignore_dynamic_dimension_ = false; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index e6273c4e7f8ed3f8feab0ecd540ad1081f653c8b..d045fc7a9e291258640eca75166e116cf7390a7b 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -1256,6 +1257,43 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& input_shape, const Shape& output_shape) { CHECK(input_shape.IsArray()); CHECK(output_shape.IsArray()); + // Removing trivial dimensions from the shape simplifies the alignment + // algorithm since ones can go in any position. + if (HasDegenerateDimensions(input_shape) || + HasDegenerateDimensions(output_shape)) { + auto simple_output_shape = + AlignLayouts(DropDegenerateDimensions(input_shape), + DropDegenerateDimensions(output_shape)); + if (!simple_output_shape) { + return absl::nullopt; + } + + auto layout = simple_output_shape->layout().minor_to_major(); + // For each one sized dimension in the output, increment the dimension + // numbers in layout that are more minor than the one. + absl::InlinedVector dim_map; + dim_map.reserve(simple_output_shape->rank()); + for (int64 i = 0; i < output_shape.rank(); ++i) { + if (output_shape.dimensions(i) != 1) { + dim_map.push_back(i); + } + } + for (int64& d : layout) { + d = dim_map[d]; + } + + // Add the ones in descending order to the layout. Descending layouts tend + // to reduce the number of copies inserted in layout assignment. + for (int64 i = output_shape.rank() - 1; i >= 0; --i) { + if (output_shape.dimensions(i) == 1) { + layout.push_back(i); + } + } + Shape output_shape_with_layout = output_shape; + *output_shape_with_layout.mutable_layout()->mutable_minor_to_major() = + layout; + return output_shape_with_layout; + } int64 input_rank = input_shape.rank(); int64 output_rank = output_shape.rank(); @@ -1304,10 +1342,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (input_dimension_product != output_dimension_product) { return absl::nullopt; } + // We also need to store an end element so that we know where the last // alignment part ends. alignment.push_back({input_rank, output_rank}); - // Now check if the physical layout can potentially be aligned to the output // shape by changing the physical layout of the output shape. We need to check // that all dimension numbers that belong to the same alignment part appear @@ -1319,40 +1357,23 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, for (int64 i = 0; i < input_rank;) { int64 current_dimension_number = input_dimension_numbers[i]; - // Skip trivial dimensions with a bound of 1. - if (input_shape.dimensions(current_dimension_number) == 1) { - ++i; - continue; - } - - // Calculate the number of non-trivial dimension bounds in the input shape - // belonging to the current alignment part. + // Trivial dimensions are stripped. + CHECK_NE(input_shape.dimensions(current_dimension_number), 1); const int64 current_alignment_index = dimension_to_alignment_index[current_dimension_number]; // Because of the special end element that we added, we can be sure that // 'current_alignment_index' is < alignment.size() - 1. CHECK_LT(current_alignment_index, alignment.size() - 1); - int64 num_non_trivial_dimensions_in_alignment_part = 0; - for (int64 j = alignment[current_alignment_index].first; - j < alignment[current_alignment_index + 1].first; ++j) { - if (input_shape.dimensions(j) != 1) { - ++num_non_trivial_dimensions_in_alignment_part; - } - } // Check that the following 'num_non_trivial_dimensions_in_alignment_part' // dimension numbers (ignoring dimension numbers with dimension bound 1) are // in descending order and belong to the current alignment part. - for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part; + for (int64 j = 0; j < alignment[current_alignment_index + 1].first - + alignment[current_alignment_index].first; ++i, ++j) { if (i == input_rank) { return absl::nullopt; } - // Skip trivial dimensions with a bound of 1. - if (input_shape.dimensions(input_dimension_numbers[i]) == 1) { - --j; - continue; - } // If the current dimension number belongs to a different alignment part, // or the dimension numbers are not in descending order, we can return // early. @@ -1363,22 +1384,11 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, } current_dimension_number = input_dimension_numbers[i]; } - // The output dimension numbers that belong to the current alignment part - // need to appear in the same descending order as in the input. Again, we - // can skip dimensions with a bound of 1. + // need to appear in the same descending order as in the input. for (int64 j = alignment[current_alignment_index + 1].second - 1; j >= alignment[current_alignment_index].second; --j) { - if (output_shape.dimensions(j) != 1) { - output_layout.push_back(j); - } - } - } - // Now add all the dimensions with dimension bound 1 at the end of - // 'output_layout'. - for (int64 i = 0; i < output_rank; ++i) { - if (output_shape.dimensions(i) == 1) { - output_layout.push_back(i); + output_layout.push_back(j); } } CHECK_EQ(output_layout.size(), output_rank); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 126ae58293d12182e9b6e30f779f681829729526..020b062f6b1b032bab958772d3a6a1e35daee38b 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -761,8 +761,15 @@ TEST(AlignmentTest, AlignLayoutsWithTrivialDimensions) { auto aligned_shape = ShapeUtil::AlignLayouts( input, ShapeUtil::MakeShape(xla::F32, {1, 4, 1, 3, 2, 7, 5, 11, 1})); EXPECT_TRUE(aligned_shape); - EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), - ElementsAre(6, 5, 4, 3, 1, 7, 0, 2, 8)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + +TEST(AlignmentTest, AlignLayoutsWithAllTrivialDimensions) { + Shape input = + ShapeUtil::MakeShapeWithLayout(xla::F32, {1, 1, 1, 1}, {0, 1, 3, 2}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {1, 1, 1, 1, 1})); + EXPECT_TRUE(aligned_shape); EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index db1c9274690583326b8a8d36413d725c14007aa3..a67aa6ebfe2c21c6b701de67e608cac12cd6ccbf 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1146,7 +1146,7 @@ xla_test( xla_test( name = "reduce_test", srcs = ["reduce_test.cc"], - shard_count = 40, + shard_count = 31, tags = [ "optonly", ], diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index cad43d1b5547d74701760fa623e50466fc15c263..4687ed61a7de91bc1bce0efeadf1965ad7d52d55 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -172,8 +172,10 @@ XLA_TEST_F(CustomCallTest, LayoutConstrained) { const Shape& r2f32_dim0_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); - b.AddInstruction(HloInstruction::CreateCustomCall( + auto custom_call = b.AddInstruction(HloInstruction::CreateCustomCall( r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major})); + b.AddInstruction( + custom_call->CloneWithNewOperands(r2f32_dim0_major, {custom_call})); module->AddEntryComputation(b.Build()); ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); @@ -182,7 +184,7 @@ XLA_TEST_F(CustomCallTest, LayoutConstrained) { Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); Literal result = ExecuteAndTransfer(std::move(module), {&argument}); - LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 5.f}}, result); + LiteralTestUtil::ExpectR2Equal({{3.f, 4.f}, {5.f, 6.f}}, result); } XLA_TEST_F(CustomCallTest, TupleOutput) { diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 7a165e69f3264d828aac1fe2e23edf695e339eaf..5d910a193dc1d2736280a4a4e81cc65824f5afca 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -1188,6 +1188,8 @@ std::vector GetEinsumTestCases() { p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"}, p{v{5, 6}, v{6, 7}, "ab,cd->dcba"}, p{v{6}, v{6, 7}, "b,bc->c"}, + p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"}, + p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"}, p{v{77}, v{77}, "a,a->a"}, p{v{77}, v{77, 55}, "a,ab->ba"}, p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"}, @@ -1265,11 +1267,11 @@ ENTRY %test { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); } -XLA_TEST_F(DotOperationTextTest, CachingBug) { +XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) { // Tests for a caching bug in the XLA CPU backend. absl::string_view hlo_string = R"( -HloModule CachingBug +HloModule CpuTiledDotEmitterCachingBug ENTRY main { lhs = f32[20,40] parameter(0) @@ -1286,5 +1288,45 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); } +XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) { + // Tests for a caching bug in the XLA CPU backend. + absl::string_view hlo_string = + R"( +HloModule CpuTiledDotEmitterCachingBug + +ENTRY main { + lhs_0 = f32[20,40] parameter(0) + rhs_0 = f32[40,1] parameter(1) + lhs_1 = f32[1,40] parameter(2) + rhs_1 = f32[20,40] parameter(3) + + dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + + dot_0_reshaped = f32[20] reshape(dot_0) + dot_1_reshaped = f32[20] reshape(dot_1) + + ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(GpuIntegerDotCodegen)) { + absl::string_view hlo_string = + R"( +HloModule SmallIntegerDot + +ENTRY SmallIntegerDot { + arg0 = s32[1,2,2] parameter(0) + arg1 = s32[1,2,1] parameter(1) + ROOT dot = s32[1,2,1] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d9d54fd2556be01a56afd36c13fcc8cf2184ece8..0151981ef16aabe9e363bc4d7f9ba96d4a1f170f 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -205,6 +205,17 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } +StatusOr> HloTestBase::ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas) { + HloRunner::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + for (auto argument : arguments) { + options.arguments.push_back(argument); + } + return test_runner_.ExecuteReplicated(std::move(module), options); +} + StatusOr> HloTestBase::MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 78bdd336e0a96999440f6331a965987ee0cb6bf2..3c2bcbb5df5ce94dd37f63d0c0e609f3ad2b60aa 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -173,6 +173,11 @@ class HloTestBase : public ::testing::Test { Literal ExecuteAndTransfer(std::unique_ptr module, absl::Span arguments); + // Executes the given module on multiple replicas. + StatusOr> ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas); + // Executes the given hlo module on two backends and compares results. // // 'arguments': the input of the hlo module. diff --git a/tensorflow/compiler/xla/tests/plugin.bzl b/tensorflow/compiler/xla/tests/plugin.bzl index 8a5d91363b619c6b214a96ad96e92742e3052541..107869fe59d43d0a9a3e2b14af2c09e4906d9f15 100644 --- a/tensorflow/compiler/xla/tests/plugin.bzl +++ b/tensorflow/compiler/xla/tests/plugin.bzl @@ -33,4 +33,3 @@ # } plugins = {} - diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 95c89b0ba6f29c453abab88e29bca13ee006455a..67d2258928f75c078588c9425359f9468f4463ed 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -238,6 +238,79 @@ StatusOr MakeFakeLiteralInternal(const Shape& shape, return std::move(literal); } +template +void PopulateWithRandomIntegralDataWithBounds(Literal* literal, + std::minstd_rand0* engine, + IntT min, IntT max) { + CHECK(engine != nullptr); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + std::uniform_int_distribution generator(min, max); + for (IntT& value : literal->data()) { + value = generator(*engine); + } +} + +// Same as MakeFakeLiteralInternal but generates random numbers in the given +// range [min, max]. Currently this works only for INT types. +StatusOr MakeFakeLiteralInternalWithBounds(const Shape& shape, + std::minstd_rand0* engine, + int64 min, int64 max) { + if (shape.IsTuple()) { + std::vector elements; + for (const Shape& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN( + Literal element, + MakeFakeLiteralInternalWithBounds(element_shape, engine, min, max)); + elements.push_back(std::move(element)); + } + return LiteralUtil::MakeTupleOwned(std::move(elements)); + } + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } + Literal literal(shape); + switch (shape.element_type()) { + case S8: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U8: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S16: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U16: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S32: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U32: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S64: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U64: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + default: + return Unimplemented( + "Unsupported type for fake random literal generation with bounds: %s", + ShapeUtil::HumanString(shape)); + } + return std::move(literal); +} + enum class ConstantType { kUnknown, kZero, kOne }; // Return the constant type required by this computation, if known. @@ -297,6 +370,10 @@ std::vector FindConstrainedUses( if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) || (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) { constrained_uses.push_back(instruction); + } else if ((opcode == HloOpcode::kGather || + opcode == HloOpcode::kScatter) && + op_num == 1) { + constrained_uses.push_back(instruction); } else if (opcode == HloOpcode::kFusion) { const HloInstruction* const to_analyze = instruction->fused_parameter(op_num); @@ -356,6 +433,22 @@ StatusOr CreateLiteralForConstrainedUses( } break; } + case HloOpcode::kGather: + case HloOpcode::kScatter: { + const Shape& operand_shape = use->operand(0)->shape(); + if (use->operand(1) == ¶m) { + auto index_map = + use->opcode() == HloOpcode::kGather + ? use->gather_dimension_numbers().start_index_map() + : use->scatter_dimension_numbers() + .scatter_dims_to_operand_dims(); + for (const auto dim_in_operand : index_map) { + index_bound = + std::min(index_bound, operand_shape.dimensions(dim_in_operand)); + } + } + break; + } case HloOpcode::kReduce: case HloOpcode::kReduceWindow: needs_constant = true; @@ -385,8 +478,8 @@ StatusOr CreateLiteralForConstrainedUses( return Unimplemented("Conflicting operand generation constraints."); } if (index_bound != INT64_MAX) { - return MakeRandomIndex(index_bound, engine) - .Reshape(param.shape().dimensions()); + return MakeFakeLiteralInternalWithBounds(param.shape(), engine, -1, + index_bound); } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 321c3fb2df6f0beccded4617e91eff69c2bce2ea..f68ee04565f3898bd3db455e3e102bc2edb6255a 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -92,12 +92,13 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 5); - EXPECT_EQ(args[0].Get({}), 0); + EXPECT_GE(args[0].Get({}), -1); + EXPECT_LE(args[0].Get({}), 1); - EXPECT_GE(args[1].Get({}), 0); - EXPECT_LE(args[0].Get({}), 2); + EXPECT_GE(args[1].Get({}), -1); + EXPECT_LE(args[1].Get({}), 2); - EXPECT_GE(args[2].Get({}), 0); + EXPECT_GE(args[2].Get({}), -1); EXPECT_LE(args[2].Get({}), 3); } @@ -122,12 +123,13 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 7); - EXPECT_EQ(args[0].Get({}), 0); + EXPECT_GE(args[0].Get({}), -1); + EXPECT_LE(args[0].Get({}), 1); - EXPECT_GE(args[1].Get({}), 0); - EXPECT_LE(args[0].Get({}), 2); + EXPECT_GE(args[1].Get({}), -1); + EXPECT_LE(args[1].Get({}), 2); - EXPECT_GE(args[2].Get({}), 0); + EXPECT_GE(args[2].Get({}), -1); EXPECT_LE(args[2].Get({}), 3); } @@ -252,5 +254,77 @@ ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] { << ShapeUtil::HumanString(args[1].shape()); } +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) { + auto module = ParseHloString(R"( + HloModule Test + +ENTRY %module(paramater.0: f32[200,100,300], parameter.1: s32[10,2]) -> + f32[10,300] { + %parameter.0 = f32[200,100,300] parameter(0) + %parameter.1 = s32[10,2] parameter(1) + ROOT gather = f32[10,300] gather(f32[200,100,300] %parameter.0, + s32[10,2] %parameter.1), + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, + index_vector_dim=1, + slice_sizes={1,1,300} +} +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + + const Shape& indices_shape = args[1].shape(); + EXPECT_TRUE( + ShapeUtil::Equal(indices_shape, ShapeUtil::MakeShape(S32, {10, 2}))) + << ShapeUtil::HumanString(indices_shape); + auto indices = args[1].data(); + for (const auto index : indices) { + EXPECT_GE(index, -1); + EXPECT_LE(index, 100); + } +} + +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForScatter) { + auto module = ParseHloString(R"( + HloModule Test + +scatter_update (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + ROOT rhs = f32[] parameter(1) +} + +ENTRY main { + operand = f32[200,100,300] parameter(0) + indices = s32[10,2] parameter(1) + updates = f32[10,300] parameter(2) + ROOT scatter = f32[200,100,300] scatter(operand, indices, updates), + to_apply=scatter_update, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 + } +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 3); + + const Shape& indices_shape = args[1].shape(); + EXPECT_TRUE( + ShapeUtil::Equal(indices_shape, ShapeUtil::MakeShape(S32, {10, 2}))) + << ShapeUtil::HumanString(indices_shape); + auto indices = args[1].data(); + for (const auto index : indices) { + EXPECT_GE(index, -1); + EXPECT_LE(index, 100); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 52fee4770ab940741723514d742e998b25765f24..ebd4bb1e42c9d1dc1f72a75514e916a2d900c30e 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -177,26 +177,6 @@ tf_cc_binary( ], ) -tf_cc_binary( - name = "dumped_computation_to_tf_graphdef", - srcs = ["dumped_computation_to_tf_graphdef.cc"], - deps = [ - "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:hlo_graph_dumper", - "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/core:lib", - "@com_google_absl//absl/types:span", - ], -) - tf_cc_binary( name = "hlo_proto_to_json", srcs = ["hlo_proto_to_json.cc"], diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc deleted file mode 100644 index f8bb9a6b1e217fc4e6e15c8a3302be61ed339c82..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Usage: dumped_computation_to_tf_graph some_binary_snapshot_proto* -// -// Dumps a tensorflow GraphDef in text format for a snapshot computation. The -// dumped graph is an HLO computation with HLO instructions as nodes and can be -// visualized on Tensorboard. Upload the dumped files on Tensorboard. -// -// some_binary_snapshot_proto is obtained by serializing the SessionModule from -// ServiceInterface::SnapshotComputation to disk. - -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" - -using tensorflow::Env; - -namespace xla { -namespace tools { - -void RealMain(absl::Span args) { - Client* client = ClientLibrary::LocalClientOrDie(); - for (char* arg : args) { - HloSnapshot module; - TF_CHECK_OK( - tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - XlaComputation computation = - client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = GetDebugOptionsFromFlags(); - debug_options.set_xla_generate_hlo_graph(".*"); - debug_options.set_xla_hlo_dump_as_graphdef(true); - ComputationStats stats = - client->GetComputationStats(computation, debug_options) - .ConsumeValueOrDie(); - fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); - } -} - -} // namespace tools -} // namespace xla - -int main(int argc, char** argv) { - std::vector flag_list; - xla::AppendDebugOptionsFlags(&flag_list); - xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_result) { - LOG(ERROR) << "\n" << usage; - return 2; - } - - tensorflow::port::InitMain(argv[0], &argc, &argv); - - absl::Span args(argv, argc); - args.remove_prefix(1); // Pop off the binary name, argv[0] - xla::tools::RealMain(args); - return 0; -} diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index c743dfd32b3a1327af480cf32ae3cdeb08ee814e..cda2d7c7c6b2403868f6d01a485753fa29a8d95f 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -30,6 +30,11 @@ def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly = **kwargs ) +def xla_py_proto_library(**kwargs): + # Note: we don't currently define a proto library target for Python in OSS. + _ignore = kwargs + pass + def xla_py_grpc_library(**kwargs): # Note: we don't currently define any special targets for Python GRPC in OSS. _ignore = kwargs diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 92834dbb02cdcd6383ceec3ffd079834b163ee6a..925fcbf88c1e8dd81ab1339d292e05eae52e0d13 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -15,11 +15,11 @@ limitations under the License. syntax = "proto3"; -import "tensorflow/compiler/xla/xla_data.proto"; -import "tensorflow/compiler/xla/service/hlo.proto"; - package xla; +import "tensorflow/compiler/xla/service/hlo.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; + // Options for the HLO insert-reduce-precision-operations pass. message HloReducePrecisionOptions { // Where and when the reduce-precision operations will be added. @@ -72,8 +72,7 @@ message DebugOptions { // Path to dump HLO graphs to. string xla_hlo_graph_path = 4; - // Dump HLO graphs as TensorFlow GraphDefs. - bool xla_hlo_dump_as_graphdef = 5; + reserved 5; // Was xla_hlo_dump_as_graphdef // HLO modules matching this regex will be dumped to LOG(INFO). Set to ".*" to // dump *all* HLO modules. @@ -171,9 +170,7 @@ message DebugOptions { // HLO graph. bool xla_hlo_graph_sharding_color = 92; - // Prefix the name scopes of the TF graph exports with "devX" device - // assignments, if available. - bool xla_hlo_tfgraph_device_scopes = 93; + reserved 93; // Was xla_hlo_tfgraph_device_scopes // If true, the GPU backend is free to use cudnn for HLO batch normalization // ops. @@ -234,7 +231,23 @@ message DebugOptions { // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing. bool xla_allow_scalar_index_dynamic_ops = 107; - // Next id: 108 + enum StepMarkerLocation { + // Generate step mark at each iteration of top level while loop, which + // is assumed to be a training loop. This is the default. + STEP_MARK_AT_ENTRY = 0; + // Generate step mark at program entry. This handles the case where each + // step are done by one or multiple programs execution. Only the first + // program will be tagged for generating step mark at program entry. + STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1; + // No step mark. + STEP_MARK_NONE = 2; + } + // Option to emit a target-specific marker to indicate the start of a training + // step. The location of the marker (if any) is determined by the option + // value. + StepMarkerLocation xla_step_marker_location = 108; + + // Next id: 109 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. @@ -306,8 +319,7 @@ message TransferToInfeedRequest { DeviceHandle device_handle = 3; } -message TransferToInfeedResponse { -} +message TransferToInfeedResponse {} message TransferFromOutfeedRequest { // This optional field directs the service to return the literal in this @@ -326,8 +338,7 @@ message ResetDeviceRequest { DeviceHandle device_handle = 1; } -message ResetDeviceResponse { -} +message ResetDeviceResponse {} message ComputationGraphStatsRequest { HloModuleProto computation = 1; @@ -350,8 +361,7 @@ message UnregisterRequest { repeated GlobalDataHandle data = 1; } -message UnregisterResponse { -} +message UnregisterResponse {} message CompileRequest { // The graph to be compiled. diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 4e127356a9fa7c921386c13c5ecd64af5ab19ed3..226299a7186ef0acb41f6d01fdeffeee06f13d4d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -624,3 +624,15 @@ message PrecisionConfig { // Next: 2 } + +// Describes whether all data-parallelism replicas will receive the same +// parameter data at each buffer. +message ParameterReplication { + // A list of boolean values for the flattened leaf buffers. Each value + // indicates whether the corresponding leaf buffer is replicated. + // + // If this field is empty, it means no buffer is replicated. Otherwise, the + // number of elements in this field must match the number of leaf buffers in + // the HLO instruction's shape. + repeated bool replicated_at_leaf_buffers = 1; +} diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 25f2640e35af5f65eab25dc60c44e3ed7ce4e512..0173b8bb064c7b2fb8a0df018204515b24cfa718 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -218,7 +218,6 @@ cc_library( "//tensorflow/contrib/tensor_forest:stats_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", - "//tensorflow/contrib/tpu:all_ops", ] + select({ "//tensorflow:android": [], "//tensorflow:ios": [], diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index f0b1c92cf7e4b760381da38febd9682ce2a4f27c..5608e7ddafa25757484d8c845c8c84a5691e143c 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -73,8 +73,7 @@ cc_binary( "-z defs", "-s", "-Wl,--gc-sections", - "-Wl,--version-script", # This line must be directly followed by LINKER_SCRIPT. - "$(location {})".format(LINKER_SCRIPT), + "-Wl,--version-script,$(location {})".format(LINKER_SCRIPT), ]), linkshared = 1, linkstatic = 1, diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index e6fda9e61757f1441b3691c2a3d57c6f1a5a0d42..d9fce6e09f47ab05074f0b4c03dd8e672ed3d2ce 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -335,6 +335,17 @@ grpc::Status BigtableTestClient::ReadModifyWriteRow( return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "ReadModifyWriteRow not implemented."); } +std::unique_ptr> +BigtableTestClient::AsyncReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to AsyncReadModifyWriteRow:" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::unique_ptr< grpc::ClientReaderInterface> BigtableTestClient::ReadRows( diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index 8e1326f2ce841368ea81fc7194a0588e5d6cd637..63d59b32dd17a2f58d3413932b69f4d704c84e48 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -46,6 +46,13 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { google::bigtable::v2::ReadModifyWriteRowRequest const& request, google::bigtable::v2::ReadModifyWriteRowResponse* response) override; + std::unique_ptr> + AsyncReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + grpc::CompletionQueue* cq) override; + std::unique_ptr< grpc::ClientReaderInterface> ReadRows(grpc::ClientContext* context, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 47d910d42a27db4b857eeb12209dfbb429dd1be2..5a8b2ba9caf0a9813cb5b3409b8a0dc3de0a45d7 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -399,8 +399,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): def testQuantileRegression(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.max_tree_depth = 6 + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -413,7 +413,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( quantiles=[0.95], learner_config=learner_config, - num_trees=100, + num_trees=12, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) @@ -428,31 +428,12 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_below_upper >= 0.92) self.assertTrue(frac_below_upper <= 0.98) - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() - model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.fit(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["scores"]) - - frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower >= 0.92) - self.assertTrue(frac_above_lower <= 0.98) - # Multi-dimensional quantile regression. def testQuantileRegressionMultiDimLabel(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.max_tree_depth = 6 + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -467,7 +448,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): quantiles=[0.95], learner_config=learner_config, label_dimension=2, - num_trees=100, + num_trees=18, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) @@ -490,35 +471,6 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_both_below_upper >= 0.91) self.assertTrue(frac_both_below_upper <= 0.99) - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( - two_dimension=True) - model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - label_dimension=2, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.fit(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["scores"]) - - count_above_lower = np.count_nonzero(lower < y, axis=0) - count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) - frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) - frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) - frac_both_above_lower = round(1. * count_both_aboce_lower / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower_0 >= 0.92) - self.assertTrue(frac_above_lower_0 <= 0.98) - self.assertTrue(frac_above_lower_1 >= 0.92) - self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.91) - self.assertTrue(frac_both_above_lower <= 0.99) - class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -712,11 +664,12 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): est.evaluate(input_fn=input_fn, steps=1) est.predict(input_fn=input_fn) - # One dimensional quantile regression. - def testQuantileRegression(self): + # Quantile regression in core is the same as in non core estimator, so we + # just check that it does not fail. + def testQuantileRegressionDoesNotThroughException(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 + learner_config.constraints.max_tree_depth = 1 learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -731,112 +684,12 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( quantiles=[0.95], learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_upper.train(input_fn=train_input_fn, steps=1000) - result_iter = model_upper.predict(input_fn=test_input_fn) - upper = [] - for prediction_dict in result_iter: - upper.append(prediction_dict["predictions"]) - - frac_below_upper = round(1. * np.count_nonzero(upper > y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_below_upper >= 0.92) - self.assertTrue(frac_below_upper <= 0.98) - - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() - model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.train(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["predictions"]) - - frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower >= 0.92) - self.assertTrue(frac_above_lower <= 0.98) - - # Multi-dimensional quantile regression. - def testQuantileRegressionMultiDimLabel(self): - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE - learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.tree_complexity = ( - 1.0 / _QUANTILE_REGRESSION_SIZE) - - train_input_fn, test_input_fn, y = _quantile_regression_input_fns( - two_dimension=True) - y = y.reshape(_QUANTILE_REGRESSION_SIZE, 2) - - # 95% percentile. - model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.95], - learner_config=learner_config, - num_trees=100, - label_dimension=2, + num_trees=1, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) model_upper.train(input_fn=train_input_fn, steps=1000) result_iter = model_upper.predict(input_fn=test_input_fn) - upper = [] - for prediction_dict in result_iter: - upper.append(prediction_dict["predictions"]) - - count_below_upper = np.count_nonzero(upper > y, axis=0) - count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) - frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) - frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) - frac_both_below_upper = round(1. * count_both_below_upper / len(y), 3) - # +/- 3% - self.assertTrue(frac_below_upper_0 >= 0.92) - self.assertTrue(frac_below_upper_0 <= 0.98) - self.assertTrue(frac_below_upper_1 >= 0.92) - self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.91) - self.assertTrue(frac_both_below_upper <= 0.99) - - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( - two_dimension=True) - model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - label_dimension=2, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.train(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["predictions"]) - - count_above_lower = np.count_nonzero(lower < y, axis=0) - count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) - frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) - frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) - frac_both_above_lower = round(1. * count_both_aboce_lower / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower_0 >= 0.92) - self.assertTrue(frac_above_lower_0 <= 0.98) - self.assertTrue(frac_above_lower_1 >= 0.92) - self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.91) - self.assertTrue(frac_both_above_lower <= 0.99) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index c3685b54e201f73039f6623443c67ba2b217a51e..ad6ff0a861af896ef0dd254bd47752d76378d63a 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -33,7 +33,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking ops.NotDifferentiable("TreeEnsembleVariable") ops.NotDifferentiable("TreeEnsembleSerialize") diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 0c319cc9bd1f720eb404a9da05227c5807ec874f..aff7105e94729942efc6e3e9d3ae23b733e8f5ed 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -33,7 +33,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") diff --git a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py index ad1191d41236e71008bff8c8a7fbd42c16e3f9c5..2a0a206d97bbf01ac382531df31a66d429842bbb 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 99ed4959fad9699f265183d71a1f3b609d7e6d30..7b3df962542a656af8052e9f2eae6e83744411f2 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -27,7 +27,7 @@ Managing dependencies: @@NoDependency @@split_dependency -Checkpointable data structures: +Trackable data structures: @@List @@Mapping @@UniqueNameTracker @@ -49,17 +49,16 @@ from tensorflow.contrib.checkpoint.python.python_state import NumpyState from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint -from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.core.protobuf.trackable_object_graph_pb2 import TrackableObjectGraph as CheckpointableObjectGraph from tensorflow.python.training.checkpoint_management import CheckpointManager -from tensorflow.python.training.checkpointable.base import Checkpointable as CheckpointableBase -from tensorflow.python.training.checkpointable.data_structures import List -from tensorflow.python.training.checkpointable.data_structures import Mapping -from tensorflow.python.training.checkpointable.data_structures import NoDependency -from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable -from tensorflow.python.training.checkpointable.util import capture_dependencies -from tensorflow.python.training.checkpointable.util import list_objects -from tensorflow.python.training.checkpointable.util import object_metadata - +from tensorflow.python.training.tracking.base import Trackable as CheckpointableBase +from tensorflow.python.training.tracking.data_structures import List +from tensorflow.python.training.tracking.data_structures import Mapping +from tensorflow.python.training.tracking.data_structures import NoDependency +from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable +from tensorflow.python.training.tracking.util import capture_dependencies +from tensorflow.python.training.tracking.util import list_objects +from tensorflow.python.training.tracking.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(module_name=__name__) diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index 4e529322c7c76797938468b405cd175609dc0a73..cd9c94c9bd72d398d183d3f3d485ab48cb2fd617 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -12,7 +12,7 @@ py_library( ":python_state", ":split_dependency", ":visualize", - "//tensorflow/python/training/checkpointable:data_structures", + "//tensorflow/python/training/tracking:data_structures", ], ) @@ -22,8 +22,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:data_structures", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:data_structures", ], ) @@ -36,8 +36,8 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) @@ -47,7 +47,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", "//third_party/py/numpy", "@six_archive//:six", ], @@ -64,7 +64,7 @@ tf_py_test( "//tensorflow/python:session", "//tensorflow/python:variables", "//tensorflow/python/eager:test", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:util", ], ) @@ -76,7 +76,7 @@ py_library( deps = [ "//tensorflow/python:control_flow_ops", "//tensorflow/python:training", - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", ], ) @@ -89,8 +89,8 @@ tf_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", "//tensorflow/python/eager:test", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) @@ -101,8 +101,8 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) @@ -118,6 +118,7 @@ tf_py_test( "//tensorflow/python/eager:test", "//tensorflow/python/keras:engine", "//tensorflow/python/keras:layers", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:util", ], + tags = ["no_oss"], # b/124472244 ) diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 97936d9e9dfd5d6e62fdf8312707a276b63e1267..a25d51980ea760dfb7f323497a397fbd94fd5f23 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -1,4 +1,4 @@ -"""Checkpointable data structures.""" +"""Trackable data structures.""" # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,12 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.training.checkpointable import base as checkpointable_lib -from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.tracking import base as trackable_lib +from tensorflow.python.training.tracking import data_structures -class UniqueNameTracker(data_structures.CheckpointableDataStructure): - """Adds dependencies on checkpointable objects with name hints. +class UniqueNameTracker(data_structures.TrackableDataStructure): + """Adds dependencies on trackable objects with name hints. Useful for creating dependencies with locally unique names. @@ -43,30 +43,30 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): def __init__(self): super(UniqueNameTracker, self).__init__() - self._maybe_initialize_checkpointable() + self._maybe_initialize_trackable() self._name_counts = {} @property def _values(self): return [dep.ref for dep in self._checkpoint_dependencies] - def track(self, checkpointable, base_name): - """Add a dependency on `checkpointable`. + def track(self, trackable, base_name): + """Add a dependency on `trackable`. Args: - checkpointable: An object to add a checkpoint dependency on. + trackable: An object to add a checkpoint dependency on. base_name: A name hint, which is uniquified to determine the dependency name. Returns: - `checkpointable`, for chaining. + `trackable`, for chaining. Raises: - ValueError: If `checkpointable` is not a checkpointable object. + ValueError: If `trackable` is not a trackable object. """ - if not isinstance(checkpointable, checkpointable_lib.Checkpointable): + if not isinstance(trackable, trackable_lib.Trackable): raise ValueError( - ("Expected a checkpointable value, got %s which does not inherit " - "from CheckpointableBase.") % (checkpointable,)) + ("Expected a trackable value, got %s which does not inherit " + "from tf.track.Trackable.") % (trackable,)) def _format_name(prefix, number): if number > 0: @@ -80,5 +80,5 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): count += 1 candidate = _format_name(base_name, count) self._name_counts[base_name] = count + 1 - self._track_value(checkpointable, name=candidate) - return checkpointable + self._track_value(trackable, name=candidate) + return trackable diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index a2d453ec6eb3dcf9aba4c52fe866756a92673c63..bace21939602666aa48a05d2abfe05ae6aae41e2 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -26,9 +26,9 @@ from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import data_structures -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import data_structures +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util class UniqueNameTrackerTests(test.TestCase): @@ -52,7 +52,7 @@ class UniqueNameTrackerTests(test.TestCase): save_root = util.Checkpoint(slots=slots) save_path = save_root.save(checkpoint_prefix) - restore_slots = tracking.AutoCheckpointable() + restore_slots = tracking.AutoTrackable() restore_root = util.Checkpoint( slots=restore_slots) status = restore_root.restore(save_path) @@ -68,7 +68,7 @@ class UniqueNameTrackerTests(test.TestCase): @test_util.run_in_graph_and_eager_modes def testExample(self): - class SlotManager(tracking.AutoCheckpointable): + class SlotManager(tracking.AutoTrackable): def __init__(self): self.slotdeps = containers.UniqueNameTracker() diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py index 969c90c78871ebff02b360f8f09623df56c9c077..737a6c30c1dce65dd7638ee52e6c26a8a40f8321 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -23,7 +23,7 @@ import six import numpy -from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.tracking import base # pylint: disable=g-import-not-at-top try: @@ -34,8 +34,8 @@ except ImportError: # pylint: enable=g-import-not-at-top -class NumpyState(base.Checkpointable): - """A checkpointable object whose NumPy array attributes are saved/restored. +class NumpyState(base.Trackable): + """A trackable object whose NumPy array attributes are saved/restored. Example usage: @@ -72,7 +72,7 @@ class NumpyState(base.Checkpointable): """Create placeholder NumPy arrays for to-be-restored attributes. Typically `_lookup_dependency` is used to check by name whether a dependency - exists. We cheat slightly by creating a checkpointable object for `name` if + exists. We cheat slightly by creating a trackable object for `name` if we don't already have one, giving us attribute re-creation behavior when loading a checkpoint. @@ -85,7 +85,7 @@ class NumpyState(base.Checkpointable): value = super(NumpyState, self)._lookup_dependency(name) if value is None: value = _NumpyWrapper(numpy.array([])) - new_reference = base.CheckpointableReference(name=name, ref=value) + new_reference = base.TrackableReference(name=name, ref=value) self._unconditional_checkpoint_dependencies.append(new_reference) self._unconditional_dependency_names[name] = value super(NumpyState, self).__setattr__(name, value) @@ -101,7 +101,7 @@ class NumpyState(base.Checkpointable): def __setattr__(self, name, value): """Automatically wrap NumPy arrays assigned to attributes.""" # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making - # ndarrays checkpointable natively and using standard checkpointable list + # ndarrays trackable natively and using standard trackable list # tracking. if isinstance(value, (numpy.ndarray, numpy.generic)): try: @@ -110,19 +110,19 @@ class NumpyState(base.Checkpointable): return except AttributeError: value = _NumpyWrapper(value) - self._track_checkpointable(value, name=name, overwrite=True) + self._track_trackable(value, name=name, overwrite=True) elif (name not in ("_setattr_tracking", "_update_uid") and getattr(self, "_setattr_tracking", True)): - # Mixing restore()-created attributes with user-added checkpointable + # Mixing restore()-created attributes with user-added trackable # objects is tricky, since we can't use the `_lookup_dependency` trick to # re-create attributes (we might accidentally steal the restoration for - # another checkpointable object). For now `NumpyState` objects must be + # another trackable object). For now `NumpyState` objects must be # leaf nodes. Theoretically we could add some extra arguments to # `_lookup_dependency` to figure out whether we should create a NumPy # array for the attribute or not. raise NotImplementedError( ("Assigned %s to the %s property of %s, which is not a NumPy array. " - "Currently mixing NumPy arrays and other checkpointable objects is " + "Currently mixing NumPy arrays and other trackable objects is " "not supported. File a feature request if this limitation bothers " "you.") % (value, name, self)) @@ -130,7 +130,7 @@ class NumpyState(base.Checkpointable): @six.add_metaclass(abc.ABCMeta) -class PythonStateWrapper(base.Checkpointable): +class PythonStateWrapper(base.Trackable): """Wraps a Python object for storage in an object-based checkpoint.""" @abc.abstractmethod diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py index 45494351ff4e6c8c75634d8563c3fb63c6089036..40d8fe836402c8b6c8240ef9f665b753c54ede0d 100644 --- a/tensorflow/contrib/checkpoint/python/python_state_test.py +++ b/tensorflow/contrib/checkpoint/python/python_state_test.py @@ -26,7 +26,7 @@ from tensorflow.python.eager import test from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import variables -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import util class NumpyStateTests(test.TestCase): diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py index 3e9700ad74618e24843181d169f3fb39ac96bff6..d7b02b538909305b14e638761bd8ba67a948d2b4 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -21,7 +21,7 @@ import functools from tensorflow.python.ops import control_flow_ops from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): @@ -43,7 +43,7 @@ class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): return self._restore_callback(tensor) -class _SplitDependency(checkpointable.Checkpointable): +class _SplitDependency(trackable.Trackable): """Looks like a regular variable while synchronizing save/restores.""" def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, @@ -81,9 +81,9 @@ class _SplitDependency(checkpointable.Checkpointable): return control_flow_ops.no_op() def _gather_saveables_for_checkpoint(self): - """Looks to Checkpointable like a regular variable.""" + """Looks to Trackable like a regular variable.""" return { - checkpointable.VARIABLE_VALUE_KEY: + trackable.VARIABLE_VALUE_KEY: functools.partial(_CallbackSaveable, dtype=self._dtype, save_callback=self._save, @@ -117,7 +117,7 @@ def split_dependency(component_names, component_dtypes, may return `None`). Returns: - A dictionary mapping from names to Checkpointable objects. If one is + A dictionary mapping from names to Trackable objects. If one is reachable from an object as a dependency, the others should be too; adding dependencies on some but not all of the objects will result in errors. """ diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index 664a4e76ab31bf31c7a57924e4af866f2d746804..9bc01059481ff69064e3f9c682a764146b79a250 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -23,9 +23,9 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training.checkpointable import base -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import base +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util def _split_variable_closure(variable): @@ -44,7 +44,7 @@ def _combine_variable_closure(variable): return _consume_restore_buffer_fn -class SaveTensorSlicesAsDeps(base.Checkpointable): +class SaveTensorSlicesAsDeps(base.Trackable): def __init__(self): self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) @@ -56,17 +56,17 @@ class SaveTensorSlicesAsDeps(base.Checkpointable): consume_restore_buffer_fn=_combine_variable_closure( self.combined)) for name, dep in split_dependencies.items(): - self._track_checkpointable(dep, name=name) + self._track_trackable(dep, name=name) -class HasRegularDeps(tracking.AutoCheckpointable): +class HasRegularDeps(tracking.AutoTrackable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) -class OnlyOneDep(tracking.AutoCheckpointable): +class OnlyOneDep(tracking.AutoTrackable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py index bac071c4cff383f60b707b6e42c13faf5e0ac948..faf90f018476b3c70a7bfa1346a5b590edbbddcd 100644 --- a/tensorflow/contrib/checkpoint/python/visualize.py +++ b/tensorflow/contrib/checkpoint/python/visualize.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.training.tracking import util as trackable_utils def dot_graph_from_checkpoint(save_path): @@ -51,7 +51,7 @@ def dot_graph_from_checkpoint(save_path): A graph in DOT format as a string. """ reader = pywrap_tensorflow.NewCheckpointReader(save_path) - object_graph = checkpointable_utils.object_metadata(save_path) + object_graph = trackable_utils.object_metadata(save_path) shape_map = reader.get_variable_to_shape_map() dtype_map = reader.get_variable_to_dtype_map() graph = 'digraph {\n' @@ -63,7 +63,7 @@ def dot_graph_from_checkpoint(save_path): slot_ids.add(slot_reference.slot_variable_node_id) for node_id, node in enumerate(object_graph.nodes): if (len(node.attributes) == 1 - and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY): + and node.attributes[0].name == trackable.VARIABLE_VALUE_KEY): if node_id in slot_ids: color = 'orange' tooltip_prefix = 'Slot variable' diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py index 583e3bc442893d825c337d73fb999d1e586738a1..98a22d573fdb6172cde100df461d9ae520c2c483 100644 --- a/tensorflow/contrib/checkpoint/python/visualize_test.py +++ b/tensorflow/contrib/checkpoint/python/visualize_test.py @@ -28,7 +28,7 @@ from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import core from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import adam -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils try: import pydot # pylint: disable=g-import-not-at-top @@ -57,7 +57,7 @@ class DotGraphTests(test.TestCase): model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = resource_variable_ops.ResourceVariable(12) - save_checkpoint = checkpointable_utils.Checkpoint( + save_checkpoint = trackable_utils.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) optimizer.minimize(functools.partial(model, input_value)) checkpoint_directory = self.get_temp_dir() diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 379b530361f42279a8d489282bf1b35f08ba74cf..ff48ba4de50c06610e1a83f0a98b1e4238d5c889 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG d0d93bdab84f2befb425e9a991d17dc78c195c6d) +set(GRPC_TAG 3dacd1afc451803fbbc4d01c53cbaf026aa9d06b) if(WIN32) # We use unsecure gRPC because boringssl does not build on windows diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 8b6395304bb81476775e5a2d8f2ec7876035778c..3d86ab9abbb4cc90c406edc6237c0d2abe440122 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -72,7 +72,7 @@ tensorflow/python/tools tensorflow/python/tools/api tensorflow/python/tools/api/generator tensorflow/python/training -tensorflow/python/training/checkpointable +tensorflow/python/training/tracking tensorflow/python/user_ops tensorflow/python/util tensorflow/python/util/protobuf diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index ca92c31236a7a3882415834eb32a994a120b6d2d..403f30909520dc5cd5f5919af843291fe1400b91 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -58,7 +58,7 @@ from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM @@ -709,7 +709,7 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): self._TestSaveRestoreHelper(CUDNN_RNN_RELU) -class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): +class CudnnRNNTestSaveRestoreTrackable(test_util.TensorFlowTestCase): def _VerifyCheckpoint( self, checkpoint_path, compatible_cell_fn, cudnn_cell_fn, @@ -718,7 +718,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") with ops.device("gpu:0"): cudnn_layer = cudnn_cell_fn() - cudnn_checkpoint = checkpointable_utils.Checkpoint(cell=cudnn_layer) + cudnn_checkpoint = trackable_utils.Checkpoint(cell=cudnn_layer) status = cudnn_checkpoint.restore(checkpoint_path) inputs = 3. * array_ops.ones([num_applications, num_layers, input_size], dtype=dtypes.float32) @@ -726,7 +726,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): status.run_restore_ops() second_save_path = cudnn_checkpoint.save(checkpoint_prefix) restore_layer = compatible_cell_fn() - restore_layer_checkpoint = checkpointable_utils.Checkpoint( + restore_layer_checkpoint = trackable_utils.Checkpoint( cell=restore_layer) status = restore_layer_checkpoint.restore(second_save_path) current_state = restore_layer.zero_state(1, dtypes.float32) @@ -742,7 +742,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): self.assertAllClose(self.evaluate(restore_layer_output), self.evaluate(cudnn_output)[-1, -1:, ...]) - def _CheckpointableSingleCellUnidirectionalTestTemplate( + def _TrackableSingleCellUnidirectionalTestTemplate( self, single_cell_fn, cudnn_cell_fn): # Single-layer cuDNN cells with object-based checkpointing should be # checkpoint compatible with either single CudnnCompatible cells or @@ -759,7 +759,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): value = np.random.normal(size=variable.shape) expected_values.append(value) self.evaluate(variable.assign(value)) - save_checkpoint = checkpointable_utils.Checkpoint(cell=save_cell_layer) + save_checkpoint = trackable_utils.Checkpoint(cell=save_cell_layer) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") first_save_path = save_checkpoint.save(checkpoint_prefix) @@ -775,10 +775,10 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @test_util.run_in_graph_and_eager_modes - def testLSTMCheckpointableSingleLayer(self): + def testLSTMTrackableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION - self._CheckpointableSingleCellUnidirectionalTestTemplate( + self._TrackableSingleCellUnidirectionalTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units), cudnn_cell_fn=functools.partial( @@ -788,19 +788,19 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @test_util.run_in_graph_and_eager_modes - def testGRUCheckpointableSingleLayer(self): + def testGRUTrackableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION with self.assertRaises(NotImplementedError): # TODO(allenl): Implement object-based saving for GRUs and other cells. - self._CheckpointableSingleCellUnidirectionalTestTemplate( + self._TrackableSingleCellUnidirectionalTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleGRUCell, num_units=num_units), cudnn_cell_fn=functools.partial( cudnn_rnn.CudnnGRU, num_layers=1, num_units=num_units, direction=direction, name="awesome_gru")) - def _CheckpointableMultiLayerTestTemplate( + def _TrackableMultiLayerTestTemplate( self, single_cell_fn, cudnn_cell_fn, num_layers): def _MultiCellFn(): @@ -819,7 +819,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): value = np.random.normal(size=variable.shape) expected_values.append(value) self.evaluate(variable.assign(value)) - save_checkpoint = checkpointable_utils.Checkpoint(cell=save_layer) + save_checkpoint = trackable_utils.Checkpoint(cell=save_layer) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") first_save_path = save_checkpoint.save(checkpoint_prefix) @@ -837,7 +837,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): num_units = 2 num_layers = 3 direction = CUDNN_RNN_UNIDIRECTION - self._CheckpointableMultiLayerTestTemplate( + self._TrackableMultiLayerTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units), cudnn_cell_fn=functools.partial( diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 86ad8ae8073714657c78badb1e0b4a6d8c8ed5f0..1cb477716dfc6a9cc793939059784f9d89bcdd8a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -518,8 +518,8 @@ class _CudnnRNN(base_layer.Layer): direction=self.direction, scope=vs.get_variable_scope(), name="%s_saveable" % self.trainable_variables[0].name.split(":")[0]) - self._saveable._add_checkpointable_dependencies( # pylint: disable=protected-access - checkpointable=self, dtype=self._plain_dtype) + self._saveable._add_trackable_dependencies( # pylint: disable=protected-access + trackable=self, dtype=self._plain_dtype) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index f36e8d5022bc7e3f8268a161089153e5510dffc6..7d848e2ec2d99cd2a78ff3e813207c0cd5bb97cf 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking as checkpointable_lib +from tensorflow.python.training.tracking import tracking as trackable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" @@ -737,13 +737,13 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): return state_ops.assign( self._variables, opaque_params, validate_shape=False) - def _checkpointable_save(self, save_buffer): + def _trackable_save(self, save_buffer): weights, biases = self.format_converter.opaque_to_tf_canonical( self._variables) for name, tensor in zip(self._param_names, weights + biases): save_buffer[name] = array_ops.identity(tensor) - def _checkpointable_restore(self, restore_buffer): + def _trackable_restore(self, restore_buffer): tensors = [ array_ops.identity(restore_buffer[name]) for name in self._param_names ] @@ -752,26 +752,26 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): restored_shapes=None # Unused ) - def _add_checkpointable_dependencies(self, checkpointable, dtype): - """Add canonical weight dependencies to `checkpointable`. + def _add_trackable_dependencies(self, trackable, dtype): + """Add canonical weight dependencies to `trackable`. When saving or restoring, converts to or from the opaque buffer format. Weights are saved and loaded in the configuration expected by cuDNN-compatible cells. Args: - checkpointable: An object inheriting from `CheckpointableBase` to add + trackable: An object inheriting from `Trackable` to add dependencies too (typically the cuDNN `Layer`). dtype: The dtype for the canonical parameter Tensors. """ split_dependencies = split_dependency.split_dependency( component_names=self._param_names, component_dtypes=(dtype,) * len(self._param_names), - fill_save_buffer_fn=self._checkpointable_save, - consume_restore_buffer_fn=self._checkpointable_restore) - self._checkpointable_track_params(checkpointable, split_dependencies) + fill_save_buffer_fn=self._trackable_save, + consume_restore_buffer_fn=self._trackable_restore) + self._trackable_track_params(trackable, split_dependencies) - def _checkpointable_track_params(self, checkpointable, params): + def _trackable_track_params(self, trackable, params): """Tracks parameters in a canonical configuration.""" return # NotImplementedError raised by the Layer. @@ -819,7 +819,7 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): tf_weights_names.append(prefix + "/kernel") tf_bias_names.append(prefix + "/bias") - def _checkpointable_track_params(self, checkpointable, params): + def _trackable_track_params(self, trackable, params): """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" biases = [] weights = [] @@ -833,12 +833,12 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): # wrapping. kernel, = weights # pylint: disable=unbalanced-tuple-unpacking bias, = biases # pylint: disable=unbalanced-tuple-unpacking - checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access - checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access + trackable._track_trackable(kernel, name="kernel") # pylint: disable=protected-access + trackable._track_trackable(bias, name="bias") # pylint: disable=protected-access assert len(biases) == len(weights) for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): - cell = checkpointable_lib.AutoCheckpointable() - checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access + cell = trackable_lib.AutoTrackable() + trackable._track_trackable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access cell.bias = bias cell.kernel = kernel diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 44ecc8c4286d594e40378e6811a085ade73cea84..63879968bfbd06d7005e57724cbc4dff1dbcbb5c 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -800,6 +800,6 @@ tf_xla_py_test( ":tpu_strategy", "//tensorflow/compiler/tests:xla_test", "//tensorflow/python/eager:test", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:util", ], ) diff --git a/tensorflow/contrib/distribute/python/checkpointing_test.py b/tensorflow/contrib/distribute/python/checkpointing_test.py index aa5b9f57b8a5bc12ee94399ec1fc5a55177a5b5d..eadf7233f2ae5ee50b71836ebfcc895163124ac2 100644 --- a/tensorflow/contrib/distribute/python/checkpointing_test.py +++ b/tensorflow/contrib/distribute/python/checkpointing_test.py @@ -30,15 +30,15 @@ from tensorflow.python.platform import test from tensorflow.python.training import adam as adam_v1 from tensorflow.python.training import checkpoint_management from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util as trackable_utils -class NonLayerCheckpointable(tracking.AutoCheckpointable): +class NonLayerTrackable(tracking.AutoTrackable): def __init__(self): - super(NonLayerCheckpointable, self).__init__() - self.a_variable = checkpointable_utils.add_variable( + super(NonLayerTrackable, self).__init__() + self.a_variable = trackable_utils.add_variable( self, name="a_variable", shape=[]) @@ -49,8 +49,8 @@ class Subclassed(training.Model): super(Subclassed, self).__init__() self._named_dense = core.Dense(1, use_bias=True) self._second = core.Dense(1, use_bias=False) - # We can still track Checkpointables which aren't Layers. - self._non_layer = NonLayerCheckpointable() + # We can still track Trackables which aren't Layers. + self._non_layer = NonLayerTrackable() def call(self, values): ret = self._second(self._named_dense(values)) @@ -76,7 +76,7 @@ class TrainingCheckpointTests(xla_test.XLATestCase): with strategy.scope(): model = Subclassed() optimizer = adam_v1.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = trackable_utils.Checkpoint( optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) root.restore(checkpoint_management.latest_checkpoint( diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index acbe4677b401cbea4fd0ec415415f25c920e68e4..ee7640dd1cea15e62ae9912ebedbd853778364a6 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -410,6 +410,7 @@ class DistributedCollectiveAllReduceStrategyTest( num_gpus=num_gpus, use_core_strategy=use_core_strategy) + # TODO(b/124344198): Re-enable after fixing this flaky test. # TODO(yuefengz): Update how we use num_gpus and required_gpus @combinations.generate( combinations.combine( @@ -418,7 +419,8 @@ class DistributedCollectiveAllReduceStrategyTest( required_gpus=1, use_dataset=[True, False], use_core_strategy=[True, False])) - def testMakeInputFnIterator(self, num_gpus, use_dataset, use_core_strategy): + def DISABLED_testMakeInputFnIterator(self, num_gpus, use_dataset, + use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if use_dataset: @@ -553,7 +555,7 @@ class LocalCollectiveAllReduceStrategy( required_gpus=2, use_dataset=[True, False], use_core_strategy=[True, False])) - def testMakeInputFnIterator(self, use_dataset, use_core_strategy): + def DISABLED_testMakeInputFnIterator(self, use_dataset, use_core_strategy): num_gpus = 2 if use_dataset: fn = lambda: dataset_ops.Dataset.range(5 * num_gpus) diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py index 10a58316ec5b3d9d968a88c5c39ff70c277daa65..204f52b034f2366a42fbdab41c467feddb5969a0 100644 --- a/tensorflow/contrib/distribute/python/input_lib_test.py +++ b/tensorflow/contrib/distribute/python/input_lib_test.py @@ -22,7 +22,6 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib @@ -214,33 +213,5 @@ class InputIteratorMultiWorkerTest( expected_values, sess) -class SplitDatasetBatchTest(test.TestCase): - - def testBatchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testMapAndBatchDataset(self): - dataset = dataset_ops.Dataset.range(100) - dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testPrefetchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 2eca1d1877f36b0dacdf8abef3b3527c5db061ec..77e241974f7c4c27382ab548a202891fdbbc6ba0 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -34,6 +34,8 @@ from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.summary.writer import writer_cache @@ -68,6 +70,20 @@ def simple_functional_model(): return model +def simple_subclassed_model(num_labels=_NUM_CLASS): + + class _SimpleMLP(keras.Model): + + def __init__(self, num_labels): + super(_SimpleMLP, self).__init__() + self.dense = keras.layers.Dense(num_labels) + + def call(self, inputs): + return self.dense(inputs) + + return _SimpleMLP(num_labels) + + def simple_multi_inputs_multi_outputs_model(): input_a = keras.layers.Input(shape=(16,), name='input_a') input_b = keras.layers.Input(shape=(16,), name='input_b') @@ -1184,5 +1200,127 @@ class TestDistributionStrategyWithDatasets(test.TestCase, atol=1e-4, rtol=1e-4) +class TestRegularizerLoss(test.TestCase, parameterized.TestCase): + class IdentityRegularizer(keras.regularizers.Regularizer): + + def __call__(self, x): + return array_ops.identity(x) + + class AddLayer(keras.layers.Layer): + + def build(self, _): + self.v = self.add_weight( + 'v', (), initializer='ones', + regularizer=TestRegularizerLoss.IdentityRegularizer()) + + def call(self, inputs): + return inputs + self.v + + @staticmethod + def loss_fn(_, y_pred): + return math_ops.reduce_mean(y_pred) + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_regularizer_loss(self, distribution): + batch_size = 2 + if not distributed_training_utils.global_batch_size_supported(distribution): + batch_size //= distribution.num_replicas_in_sync + + # Given an input x, which is always 1, and variable v, this model computes + # Loss=x+v+regularizer_loss, where regularizer_loss=v and the variable is + # initialized to 1. Therefore, this model computes Loss=1+2v, and so the + # gradient dLoss/dv = 2. This gradient of 2 is averaged over all examples + # in a batch and then multiplied by the learning rate of 1. As a result, + # the model update for one batch should subtract 2 from v, resulting in v + # being -1. If the regularizer loss is not scaled correctly by number of + # replicas, the variable value will be incorrect when number of replicas + # >1. For e.g. it will be -2 if num replicas = 2. + with distribution.scope(): + x = keras.layers.Input(shape=(), batch_size=batch_size) + y = TestRegularizerLoss.AddLayer()(x) + model = keras.models.Model(inputs=x, outputs=y) + opt = gradient_descent_keras.SGD(1.) + model.compile(opt, loss=TestRegularizerLoss.loss_fn) + model.fit( + x=np.array([[1.], [1.]], dtype=np.float32), + y=np.array([[1.], [1.]], dtype=np.float32), + batch_size=batch_size) + v = model.get_weights()[0] + self.assertEqual(-1.0, v) + + +class TestDistributionStrategyWithKerasModels(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations()) + def test_distribution_strategy_on_sequential_model(self, distribution): + with distribution.scope(): + model = simple_sequential_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) + + inputs = np.zeros((20, 10), np.float32) + targets = np.zeros((20, 2), np.float32) + + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + model.predict(inputs, steps=1) + model.evaluate(inputs, targets, steps=1) + + @combinations.generate(all_strategy_combinations()) + def test_distribution_strategy_on_functional_model(self, distribution): + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + model.predict(inputs, steps=1) + model.evaluate(inputs, targets, steps=1) + + # TODO(b/124377929): Remove error assertions once subclassed models + # are supported in DistributedStrategy. + @combinations.generate(all_strategy_combinations_minus_default()) + def test_distribution_strategy_on_subclassed_model(self, distribution): + with distribution.scope(): + model = simple_subclassed_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 2), dtype=np.float32) + + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.predict(inputs, steps=1) + + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.evaluate(inputs, targets, steps=1) + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_distribution_strategy_one_dimensional(self, distribution): + with distribution.scope(): + inp = keras.layers.Input(shape=(10,)) + out = keras.layers.Dense(3, activation='softmax')(inp) + model = keras.Model(inputs=[inp], outputs=[out]) + model.compile( + optimizer='rmsprop', + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy'], + ) + + x = np.random.random((64, 10)).astype('float32') + y = np.random.randint(3, size=64) + + model.fit(x, y, epochs=1, steps_per_epoch=2) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/keras_utils_test.py b/tensorflow/contrib/distribute/python/keras_utils_test.py index 3e5b422f512c0de5ed90ba73b56448ebbf4fc8a7..36eaee77f21a9f6d62a7c3f616d0126b7a4a8902 100644 --- a/tensorflow/contrib/distribute/python/keras_utils_test.py +++ b/tensorflow/contrib/distribute/python/keras_utils_test.py @@ -414,7 +414,7 @@ class TestDistributionStrategySaveLoadWeights(test.TestCase, @combinations.generate( keras_test_lib.all_strategy_combinations_minus_default()) - def test_save_load_checkpointable(self, distribution): + def test_save_load_trackable(self, distribution): # TODO(sourabhbajaj): Test fails with optimizer v2 without h5 with self.cached_session(): dataset = keras_test_lib.get_dataset(distribution) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index bc0572bb4618967aa13599320218b63a5eec8d10..5ce731816ccefe36c1f876c79589e448f00b86f5 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -116,7 +116,8 @@ class MirroredTwoDeviceDistributionTest( self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, expected_values) - def testMakeInputFnIteratorWithCallable(self, distribution): + # TODO(b/124344198): Re-enable after fixing this flaky test. + def DISABLED_testMakeInputFnIteratorWithCallable(self, distribution): def fn(): dataset = dataset_ops.Dataset.range(2).interleave( (lambda _: dataset_ops.Dataset.range(10)), cycle_length=2) @@ -1455,7 +1456,7 @@ class MultiWorkerMirroredStrategyTest( self._test_input_fn_iterator( iterator, distribution.extended.worker_devices, expected_values, sess) - def testMakeInputFnIteratorWithCallable(self, distribution): + def DISABLED_testMakeInputFnIteratorWithCallable(self, distribution): self._configure_distribution_strategy(distribution) def fn(): dataset = dataset_ops.Dataset.range(100) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index fede253d13804087476fef8b7211a6bfe5789906..3de2041ae35775de6df5bca02c0f1d04a9c2f24e 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -696,6 +696,7 @@ class ParameterServerStrategyTest( def testMinimizeLossGraphLocal(self, num_gpus, use_core_strategy): self._test_minimize_loss_graph(None, None, num_gpus, use_core_strategy) + # TODO(b/124344198): Re-enable after fixing this flaky test. # TODO(priyag): Refactor this and other multi worker tests. @combinations.generate( combinations.combine( @@ -704,8 +705,8 @@ class ParameterServerStrategyTest( required_gpus=1, use_core_strategy=[True, False], use_dataset=[True, False])) - def testMakeInputFnIteratorDistributed(self, num_gpus, use_core_strategy, - use_dataset): + def DISABLED_testMakeInputFnIteratorDistributed( + self, num_gpus, use_core_strategy, use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if use_dataset: @@ -732,6 +733,7 @@ class ParameterServerStrategyTest( test_reinitialize=use_dataset, use_core_strategy=use_core_strategy) + # TODO(b/124344198): Re-enable after fixing this flaky test. @combinations.generate( combinations.combine( mode=['graph'], @@ -739,8 +741,8 @@ class ParameterServerStrategyTest( required_gpus=1, use_core_strategy=[True, False], use_dataset=[True, False])) - def testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy, - use_dataset): + def DISABLED_testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy, + use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if use_dataset: diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 8966a9befcd3db4a3f397b319e80f37f84ad236b..d441e4735b64fe1176e77a978d281d46a7b287ab 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -144,7 +144,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", ], ) diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 78ab155896cfeda4dd259a8529f4b1f77a12cf0b..48925b1bfacc6b59c210b2fb4b53a9a1a851673f 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,7 +37,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.training import checkpoint_management -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils class IteratorTest(test.TestCase): @@ -238,7 +238,7 @@ class IteratorTest(test.TestCase): dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) dataset = dataset.map(math_ops.square).batch(2) iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) self.assertAllEqual([1, 4], iterator.get_next().numpy()) save_path = checkpoint.save(checkpoint_prefix) self.assertAllEqual([9, 16], iterator.get_next().numpy()) @@ -257,7 +257,7 @@ class IteratorTest(test.TestCase): dataset_2 = Dataset.range(10) iterator_3 = datasets.Iterator(dataset_2) - checkpoint = checkpointable_utils.Checkpoint( + checkpoint = trackable_utils.Checkpoint( iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) self.assertAllEqual([1, 4], iterator_1.get_next().numpy()) self.assertEqual(0, iterator_3.get_next().numpy()) @@ -279,7 +279,7 @@ class IteratorTest(test.TestCase): dataset = Dataset.range(3) iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) self.assertEqual(0, iterator.get_next().numpy()) self.assertEqual(1, iterator.get_next().numpy()) save_path = checkpoint.save(checkpoint_prefix) @@ -293,7 +293,7 @@ class IteratorTest(test.TestCase): dataset = Dataset.range(10) for i in range(5): iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) checkpoint.restore(checkpoint_management.latest_checkpoint( checkpoint_directory)) for j in range(2): diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index d18a097063c7d25947af3e2e2959ce574edd553f..3143270ccfe4f670428c80bdc1e09fa452584207 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -37,7 +37,7 @@ from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.training import checkpoint_management -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils # pylint: enable=g-bad-import-order @@ -421,7 +421,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - object_graph = checkpointable_utils.object_metadata( + object_graph = trackable_utils.object_metadata( checkpoint_management.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index c8d9266672a8b87d32338ea7c4f74fb40d41c767..b32501c2e804838af9d4c77663be131b77bd30b4 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -32,12 +32,12 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable _to_replace = re.compile("[^A-Za-z0-9.]") -class Metric(checkpointable.Checkpointable): +class Metric(trackable.Trackable): """A metric holds state for aggregating statistics over an evaluation run. Example use with eager execution: @@ -269,7 +269,7 @@ class Metric(checkpointable.Checkpointable): else: collections = [ops.GraphKeys.LOCAL_VARIABLES] collections += [ops.GraphKeys.METRIC_VARIABLES] - # Variables are Checkpointable dependencies of Metrics regardless of the + # Variables are Trackable dependencies of Metrics regardless of the # global/local distinction. Users can avoid saving variables by not adding a # dependency on the Metric. v = self._add_variable_with_custom_getter( @@ -282,7 +282,7 @@ class Metric(checkpointable.Checkpointable): use_resource=True, getter=variable_scope.get_variable, # Raise duplicate variable exceptions from get_variable rather than - # Checkpointable. + # Trackable. overwrite=True) self._vars.append(v) if context.executing_eagerly(): diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 39e5957f5d1760613f2c33607c0bdb163040efb4..c56d1956fde35b562e60496015e666efe9ebc8f6 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -35,7 +35,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils class MetricsTest(test.TestCase): @@ -314,7 +314,7 @@ class MetricsTest(test.TestCase): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") mean = metrics.Mean() - checkpoint = checkpointable_utils.Checkpoint(mean=mean) + checkpoint = trackable_utils.Checkpoint(mean=mean) mean.build() mean._built = True self.evaluate(mean.init_variables()) @@ -327,7 +327,7 @@ class MetricsTest(test.TestCase): self.assertAllEqual(200., self.evaluate(mean.value())) restore_mean = metrics.Mean() - restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean) + restore_checkpoint = trackable_utils.Checkpoint(mean=restore_mean) status = restore_checkpoint.restore(save_path) restore_update = restore_mean(300.) status.assert_consumed().run_restore_ops() diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 240f213c602395b8589d39c3ecd90b602ffa9848..b3e8daddaf2369e9e33179fde2aab1469e97ea47 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils # pylint: disable=not-callable @@ -65,7 +65,7 @@ class NetworkTest(test.TestCase): def test_checkpointing_not_implemented(self): checkpoint_directory = self.get_temp_dir() - checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork()) + checkpoint = trackable_utils.Checkpoint(net=MyNetwork()) with self.assertRaises(NotImplementedError): checkpoint.save(checkpoint_directory) diff --git a/tensorflow/contrib/eager/python/parameter_server.py b/tensorflow/contrib/eager/python/parameter_server.py index 7803a6799bb64441fab881bf6ca986d5cf3851a8..258f0a19309235dcd99b31b4de3d35ef8d89b15b 100644 --- a/tensorflow/contrib/eager/python/parameter_server.py +++ b/tensorflow/contrib/eager/python/parameter_server.py @@ -30,7 +30,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): @@ -129,8 +129,8 @@ class SharedVariable(resource_variable_ops.ResourceVariable): if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") - if isinstance(initial_value, checkpointable.CheckpointInitialValue): - self._maybe_initialize_checkpointable() + if isinstance(initial_value, trackable.CheckpointInitialValue): + self._maybe_initialize_trackable() self._update_uid = initial_value.checkpoint_position.restore_uid initial_value = initial_value.wrapped_value diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 12bbdc08cffe3bcb922e7d75c04566f7741fb7f5..df5b059448f735f7dc1f2963ffbc9c8a8287250a 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -137,8 +137,8 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable -from tensorflow.python.training.checkpointable.util import Checkpoint +from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable +from tensorflow.python.training.tracking.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 48a6ef4dca0ca7682f7b99b66177679f29ad9ec9..da2479a0b7b029561136903c82cabed9aae622b8 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -203,10 +203,7 @@ py_test( srcs = ["python/ops/kmeans_test.py"], shard_count = 4, srcs_version = "PY2AND3", - tags = [ - "nomac", # b/73741358 - "notsan", # b/67512932 - ], + tags = ["notsan"], deps = [ ":factorization_py", ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 8fc5f1cfe7800653ef1e43c6d40d1a66e34f2106..0a9199d61f36f10c98b95d79ece7e86765d2db0e 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -14,7 +14,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":sequence_feature_column", - ":sequence_feature_column_v2", "//tensorflow/python:util", ], ) @@ -72,60 +71,3 @@ tf_py_test( ], tags = ["no_pip"], ) - -py_library( - name = "sequence_feature_column_v2", - srcs = ["python/feature_column/sequence_feature_column_v2.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_py", - ], -) - -tf_py_test( - name = "sequence_feature_column_v2_test", - srcs = ["python/feature_column/sequence_feature_column_v2_test.py"], - additional_deps = [ - ":sequence_feature_column_v2", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python/feature_column:feature_column_py", - "//tensorflow/python/feature_column:feature_column_v2_test", - ], - tags = ["no_pip"], -) - -py_test( - name = "sequence_feature_column_v2_integration_test", - srcs = ["python/feature_column/sequence_feature_column_v2_integration_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":sequence_feature_column_v2", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/feature_column:feature_column_py", - "//tensorflow/python/keras:layers", - ], -) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 94fb35b3346ecd64cec5a89e495c7a2d1af3584b..063717f08aa88f4de9470d8392db2b7c95b3e4bf 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -127,6 +127,7 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['nest'] _nest_allowed_symbols = [ 'assert_same_structure', + 'is_nested', 'is_sequence', 'is_sequence_or_composite', 'flatten', diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index db0868fb2c43464a811b3d6dfcd96480ba2463ee..386e4cf69b7aa118a85fb25bcb809a879c5c1bd8 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -377,7 +377,10 @@ py_test( name = "classifier_metrics_test", srcs = ["python/eval/python/classifier_metrics_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_pip", + "no_windows", + ], deps = [ ":classifier_metrics", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index d319aa7986d81cf9ac2d1dc2e15b053a0aa0c31b..92016e6a83975a9b15a39a15125e0eabc111912e 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -19,16 +19,25 @@ tf_cc_binary( "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:candidate_sampling_ops_op_lib", "//tensorflow/core:control_flow_ops_op_lib", + "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework_internal", "//tensorflow/core:functional_ops_op_lib", + "//tensorflow/core:io_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:list_ops_op_lib", + "//tensorflow/core:logging_ops_op_lib", + "//tensorflow/core:lookup_ops_op_lib", "//tensorflow/core:manip_ops_op_lib", "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", + "//tensorflow/core:parsing_ops_op_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:random_ops_op_lib", "//tensorflow/core:remote_fused_graph_ops_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:sparse_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:string_ops_op_lib", "//tensorflow/core:training_ops_op_lib", "//tensorflow/core:user_ops_op_lib", diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 403b522ce45ac6ad98a321378626b87aaa7738aa..9d9524e4e4b995d795b7c71b5bd083d11c60d5ce 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2308,7 +2308,7 @@ def layer_norm(inputs, initializer=init_ops.ones_initializer(), collections=gamma_collections, trainable=trainable) - # Calculate the moments on the last axis (layer activations). + # By default, compute the moments across all the dimensions except the one with index 0. norm_axes = list(range(begin_norm_axis, inputs_rank)) mean, variance = nn.moments(inputs, norm_axes, keep_dims=True) # Compute layer normalization using the batch_normalization function. diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index a28394964a12013c43d85701b5a0ab5c559afd62..8fda828e994bc2436eaba4475077020436703631 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation -# TODO(rohanj): This should subclass Checkpointable and implement +# TODO(rohanj): This should subclass Trackable and implement # _gather_saveables_for_checkpoint. class ShardedMutableDenseHashTable(object): """A sharded version of MutableDenseHashTable. diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 591eabc66c49f301cf73cd912ebbef70cc9e1e3f..9fe8dafcc8edd6b80625c61a4a0e783e65b44720 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -1483,3 +1483,4 @@ class IdTableWithHashBucketsTest(test.TestCase): if __name__ == "__main__": test.main() + diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index 9ea94c74330e3e49414a6a84cd5bc0db3778114a..0a0ba36232075460b561bc54a95fc24973017571 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -40,7 +40,6 @@ tensorflow/core/lib/wav/wav_io.cc tensorflow/core/platform/cpu_info.cc tensorflow/core/platform/default/logging.cc tensorflow/core/platform/default/mutex.cc -tensorflow/core/platform/default/protobuf.cc tensorflow/core/platform/default/tracing.cc tensorflow/core/platform/denormal.cc tensorflow/core/platform/env.cc @@ -53,6 +52,7 @@ tensorflow/core/platform/posix/error.cc tensorflow/core/platform/posix/load_library.cc tensorflow/core/platform/posix/port.cc tensorflow/core/platform/posix/posix_file_system.cc +tensorflow/core/platform/protobuf.cc tensorflow/core/platform/protobuf_util.cc tensorflow/core/platform/setround.cc tensorflow/core/platform/tensor_coding.cc diff --git a/tensorflow/contrib/memory_stats/BUILD b/tensorflow/contrib/memory_stats/BUILD index 63843b993c16363a80b64622af665aaa64e05830..93701249cc8bf722c8c8558e91e0b700ca1c4a04 100644 --- a/tensorflow/contrib/memory_stats/BUILD +++ b/tensorflow/contrib/memory_stats/BUILD @@ -10,6 +10,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -45,6 +46,28 @@ tf_gen_op_wrapper_py( deps = [":memory_stats_ops_op_lib"], ) +tf_gen_op_wrapper_cc( + name = "memory_stats_ops", + out_ops_file = "memory_stats_ops", +) + +cc_library( + name = "memory_stats_cc", + srcs = ["memory_stats_ops.cc"], + hdrs = ["memory_stats_ops.h"], + visibility = ["//visibility:public"], + deps = [ + ":memory_stats_kernels", + ":memory_stats_ops_op_lib", + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + tf_custom_op_py_library( name = "memory_stats_py", srcs = [ diff --git a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc index 974fb537499c5ea4591a0a128f53d2dea67b9e57..7ae1dbeaa2d04d7846e7fada117f3941319cc1c1 100644 --- a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc +++ b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc @@ -24,13 +24,15 @@ class MemoryStatsOp : public OpKernel { void Compute(OpKernelContext* context) override { Allocator* allocator = context->device()->GetAllocator(AllocatorAttributes()); - AllocatorStats allocator_stats; - allocator->GetStats(&allocator_stats); + absl::optional allocator_stats = allocator->GetStats(); + if (!allocator_stats) { + *allocator_stats = AllocatorStats(); + } Tensor* output_tensor = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, TensorShape({}), &output_tensor)); - output_tensor->scalar()() = ExtractAllocatorStats(allocator_stats); + output_tensor->scalar()() = ExtractAllocatorStats(*allocator_stats); } protected: @@ -71,7 +73,7 @@ class BytesLimitOp : public MemoryStatsOp { private: int64 ExtractAllocatorStats( const AllocatorStats& allocator_stats) const override { - return allocator_stats.bytes_limit; + return allocator_stats.bytes_limit ? *allocator_stats.bytes_limit : -1; } }; @@ -93,7 +95,7 @@ class MaxBytesInUseOp : public MemoryStatsOp { private: int64 ExtractAllocatorStats( const AllocatorStats& allocator_stats) const override { - return allocator_stats.max_bytes_in_use; + return allocator_stats.peak_bytes_in_use; } }; diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index b5de726a4cf833d23d668968b8080e7e484cd496..b2ea3daf82ed8daa6e0b9acd8e3cf258b8181615 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -44,15 +44,15 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import graph_view -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import graph_view +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util -class NonLayerCheckpointable(tracking.AutoCheckpointable): +class NonLayerTrackable(tracking.AutoTrackable): def __init__(self): - super(NonLayerCheckpointable, self).__init__() + super(NonLayerTrackable, self).__init__() self.a_variable = util.add_variable( self, name="a_variable", shape=[]) @@ -65,8 +65,8 @@ class MyModel(training.Model): super(MyModel, self).__init__() self._named_dense = core.Dense(1, use_bias=True) self._second = core.Dense(1, use_bias=False) - # We can still track Checkpointables which aren't Layers. - self._non_layer = NonLayerCheckpointable() + # We can still track Trackables which aren't Layers. + self._non_layer = NonLayerTrackable() def call(self, values): ret = self._second(self._named_dense(values)) @@ -101,7 +101,7 @@ class CheckpointingTests(test.TestCase): other_model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) if context.executing_eagerly(): optimizer.minimize( @@ -117,10 +117,10 @@ class CheckpointingTests(test.TestCase): other_model(input_value), global_step=optimizer_step) self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) named_variables, serialized_graph, _ = graph_view.ObjectGraphView( - root_checkpointable).serialize_object_graph() + root_trackable).serialize_object_graph() expected_checkpoint_names = ( # Created in the root node, so no prefix. "optimizer_step", @@ -208,7 +208,7 @@ class CheckpointingTests(test.TestCase): def testSaveRestore(self): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model) input_value = constant_op.constant([[3.]]) if context.executing_eagerly(): @@ -217,24 +217,24 @@ class CheckpointingTests(test.TestCase): else: train_op = optimizer.minimize(model(input_value)) # TODO(allenl): Make initialization more pleasant when graph building. - root_checkpointable.save_counter # pylint: disable=pointless-statement + root_trackable.save_counter # pylint: disable=pointless-statement self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) prefix = os.path.join(self.get_temp_dir(), "ckpt") self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.])) m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m") self.evaluate(state_ops.assign(m_bias_slot, [1.5])) - save_path = root_checkpointable.save(file_prefix=prefix) + save_path = root_trackable.save(file_prefix=prefix) self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.])) - self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3)) + self.evaluate(state_ops.assign(root_trackable.save_counter, 3)) optimizer_variables = self.evaluate(optimizer.variables()) self.evaluate(state_ops.assign(m_bias_slot, [-2.])) # Immediate restoration - status = root_checkpointable.restore(save_path=save_path).assert_consumed() + status = root_trackable.restore(save_path=save_path).assert_consumed() status.run_restore_ops() self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1])) - self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter)) + self.assertAllEqual(1, self.evaluate(root_trackable.save_counter)) self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) if not context.executing_eagerly(): return # Restore-on-create is only supported when executing eagerly @@ -542,11 +542,11 @@ class CheckpointingTests(test.TestCase): first_session = session_lib.Session(graph=first_graph) with first_graph.as_default(), first_session.as_default(): first_variable = resource_variable_ops.ResourceVariable([1.]) - first_root_checkpointable = util.Checkpoint( + first_root_trackable = util.Checkpoint( optimizer=optimizer, variable=first_variable) train_op = optimizer.minimize(first_variable.read_value) self.evaluate(util.gather_initializers( - first_root_checkpointable)) + first_root_trackable)) self.evaluate(train_op) self.evaluate(first_variable.assign([1.])) self.evaluate(optimizer.get_slot( @@ -558,23 +558,23 @@ class CheckpointingTests(test.TestCase): second_graph = ops.Graph() with second_graph.as_default(), session_lib.Session(graph=second_graph): second_variable = resource_variable_ops.ResourceVariable([1.]) - second_root_checkpointable = util.Checkpoint( + second_root_trackable = util.Checkpoint( optimizer=optimizer, variable=second_variable) train_op = optimizer.minimize(second_variable.read_value) - second_root_checkpointable.restore(None).initialize_or_restore() + second_root_trackable.restore(None).initialize_or_restore() self.evaluate(train_op) self.evaluate(second_variable.assign([4.])) self.evaluate(optimizer.get_slot( var=second_variable, name="m").assign([5.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(6.)) - save_path = second_root_checkpointable.save(checkpoint_prefix) + save_path = second_root_trackable.save(checkpoint_prefix) self.evaluate(second_variable.assign([7.])) self.evaluate(optimizer.get_slot( var=second_variable, name="m").assign([8.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.assertAllEqual(6., self.evaluate(beta_1_power)) - status = second_root_checkpointable.restore(save_path) + status = second_root_trackable.restore(save_path) status.assert_consumed().run_restore_ops() self.assertAllEqual([4.], self.evaluate(second_variable)) self.assertAllEqual([5.], self.evaluate(optimizer.get_slot( @@ -594,7 +594,7 @@ class CheckpointingTests(test.TestCase): class TemplateTests(test.TestCase): @test_util.run_in_graph_and_eager_modes - def test_checkpointable_save_restore(self): + def test_trackable_save_restore(self): def _templated(): v = variable_scope.get_variable( @@ -641,13 +641,13 @@ class CheckpointCompatibilityTests(test.TestCase): model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) train_op = optimizer.minimize( functools.partial(model, input_value), global_step=optimizer_step) self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) # A regular variable, a slot variable, and a non-slot Optimizer variable # with known values to check when loading. @@ -656,24 +656,24 @@ class CheckpointCompatibilityTests(test.TestCase): var=model._named_dense.bias, name="m").assign([2.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(3.)) - return root_checkpointable + return root_trackable - def _set_sentinels(self, root_checkpointable): - self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.])) + def _set_sentinels(self, root_trackable): + self.evaluate(root_trackable.model._named_dense.bias.assign([101.])) self.evaluate( - root_checkpointable.optimizer.get_slot( - var=root_checkpointable.model._named_dense.bias, name="m") + root_trackable.optimizer.get_slot( + var=root_trackable.model._named_dense.bias, name="m") .assign([102.])) - beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(103.)) - def _check_sentinels(self, root_checkpointable): + def _check_sentinels(self, root_trackable): self.assertAllEqual( - [1.], self.evaluate(root_checkpointable.model._named_dense.bias)) + [1.], self.evaluate(root_trackable.model._named_dense.bias)) self.assertAllEqual([2.], self.evaluate( - root_checkpointable.optimizer.get_slot( - var=root_checkpointable.model._named_dense.bias, name="m"))) - beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + root_trackable.optimizer.get_slot( + var=root_trackable.model._named_dense.bias, name="m"))) + beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators() self.assertAllEqual(3., self.evaluate(beta_1_power)) def _write_name_based_checkpoint(self): @@ -698,7 +698,7 @@ class CheckpointCompatibilityTests(test.TestCase): self._set_sentinels(root) with self.assertRaises(AssertionError): self._check_sentinels(root) - object_saver = util.CheckpointableSaver(graph_view.ObjectGraphView(root)) + object_saver = util.TrackableSaver(graph_view.ObjectGraphView(root)) self._set_sentinels(root) status = object_saver.restore(save_path) if context.executing_eagerly(): diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index a49149e592f72b1977aae67078d7f41ca6f03d94..a7f978634ed45012144b2cc49ed069f6fca44f66 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -38,7 +38,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import nest @@ -223,7 +223,7 @@ class _OptimizerV2State(object): } self._slots = {} self._non_slot_dict = {} - # Extra state to help Optimizers implement Checkpointable. Holds information + # Extra state to help Optimizers implement Trackable. Holds information # about variables which will be restored as soon as they're created. self._deferred_dependencies = {} # Non-slot variables self._deferred_slot_restorations = {} # Slot variables @@ -366,8 +366,8 @@ class _OptimizerV2State(object): slot variable needs to be restored). Args: - slot_variable_position: A `checkpointable._CheckpointPosition` object - indicating the slot variable `Checkpointable` object to be restored. + slot_variable_position: A `trackable._CheckpointPosition` object + indicating the slot variable `Trackable` object to be restored. slot_name: The name of this `Optimizer`'s slot to restore into. variable: The variable object this slot is being created for. optional_op_name: Name to use when scoping the Variable that needs to be @@ -385,7 +385,7 @@ class _OptimizerV2State(object): # (aside from double initialization), and makes variable creator scopes # behave the same way they do when graph building. and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access - initializer = checkpointable.CheckpointInitialValue( + initializer = trackable.CheckpointInitialValue( checkpoint_position=slot_variable_position) slot_variable = self.create_slot( var=variable, @@ -1259,10 +1259,10 @@ class OptimizerV2(optimizer_v1.Optimizer): return self._per_graph_state.get(var._graph_key, None) # -------------- - # Overridden methods from Checkpointable. + # Overridden methods from Trackable. # -------------- - def _track_checkpointable(self, *args, **kwargs): + def _track_trackable(self, *args, **kwargs): """Optimizers may not track dependencies. Raises an error.""" raise NotImplementedError( "Optimizers may not have dependencies. File a feature request if this " @@ -1270,7 +1270,7 @@ class OptimizerV2(optimizer_v1.Optimizer): @property def _checkpoint_dependencies(self): - """From Checkpointable. Gather graph-specific non-slot variables to save.""" + """From Trackable. Gather graph-specific non-slot variables to save.""" current_graph_non_slot_variables = [] state = self._get_per_graph_state() if state is not None: @@ -1279,14 +1279,14 @@ class OptimizerV2(optimizer_v1.Optimizer): # Avoid comparing variables key=lambda item: item[0]): current_graph_non_slot_variables.append( - checkpointable.CheckpointableReference( + trackable.TrackableReference( name=name, ref=variable_object)) # Note: ignores super(); Optimizers may not have any dependencies outside of # state objects. return current_graph_non_slot_variables def _lookup_dependency(self, name): - """From Checkpointable. Find a non-slot variable in the current graph.""" + """From Trackable. Find a non-slot variable in the current graph.""" state = self._get_per_graph_state() if state is None: return None @@ -1295,10 +1295,10 @@ class OptimizerV2(optimizer_v1.Optimizer): @property def _deferred_dependencies(self): - """Lets Checkpointable know where non-slot variables are created. + """Lets Trackable know where non-slot variables are created. If necessary, creates a new state object for the current default graph. - Checkpointable will then add entries to that state's deferred dependency + Trackable will then add entries to that state's deferred dependency dictionary. The state object will check that dictionary when creating non-slot variables, restoring their value if an entry is found. @@ -1311,14 +1311,14 @@ class OptimizerV2(optimizer_v1.Optimizer): def _create_or_restore_slot_variable(self, slot_variable_position, slot_name, variable): - """Checkpointable: Restore a slot variable's value, possibly creating it. + """Trackable: Restore a slot variable's value, possibly creating it. Called when a variable which has an associated slot variable is created or restored. Args: - slot_variable_position: A `checkpointable._CheckpointPosition` object - indicating the slot variable `Checkpointable` object to be restored. + slot_variable_position: A `trackable._CheckpointPosition` object + indicating the slot variable `Trackable` object to be restored. slot_name: The name of this `Optimizer`'s slot to restore into. variable: The variable object this slot is being created for. """ diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 5e28e651c666b1c448f778fc9c02d637ce817bae..56f2a0acc9f2e6f951c5df26a53a31645697da4f 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -25,10 +25,13 @@ from tensorflow.contrib.seq2seq.python.ops import attention_wrapper from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops @@ -530,11 +533,10 @@ class BeamSearchDecoderTest(test.TestCase): return (shape[1], shape[0]) + shape[2:] return shape - self.assertTrue( - isinstance(final_outputs, - beam_search_decoder.FinalBeamSearchDecoderOutput)) - self.assertTrue( - isinstance(final_state, beam_search_decoder.BeamSearchDecoderState)) + self.assertIsInstance( + final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput) + self.assertIsInstance( + final_state, beam_search_decoder.BeamSearchDecoderState) beam_search_decoder_output = final_outputs.beam_search_decoder_output self.assertEqual( @@ -574,5 +576,119 @@ class BeamSearchDecoderTest(test.TestCase): with_alignment_history=True) +@test_util.run_all_in_graph_and_eager_modes +class BeamSearchDecoderV2Test(test.TestCase): + + def _testDynamicDecodeRNN(self, time_major, has_attention, + with_alignment_history=False): + encoder_sequence_length = np.array([3, 2, 3, 1, 1]) + decoder_sequence_length = np.array([2, 0, 1, 2, 3]) + batch_size = 5 + decoder_max_time = 4 + input_depth = 7 + cell_depth = 9 + attention_depth = 6 + vocab_size = 20 + end_token = vocab_size - 1 + start_token = 0 + embedding_dim = 50 + max_out = max(decoder_sequence_length) + output_layer = layers.Dense(vocab_size, use_bias=True, activation=None) + beam_width = 3 + + with self.cached_session(): + batch_size_tensor = constant_op.constant(batch_size) + embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32) + cell = rnn_cell.LSTMCell(cell_depth) + initial_state = cell.zero_state(batch_size, dtypes.float32) + coverage_penalty_weight = 0.0 + if has_attention: + coverage_penalty_weight = 0.2 + inputs = array_ops.placeholder_with_default( + np.random.randn(batch_size, decoder_max_time, input_depth).astype( + np.float32), + shape=(None, None, input_depth)) + tiled_inputs = beam_search_decoder.tile_batch( + inputs, multiplier=beam_width) + tiled_sequence_length = beam_search_decoder.tile_batch( + encoder_sequence_length, multiplier=beam_width) + attention_mechanism = attention_wrapper.BahdanauAttention( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + initial_state = beam_search_decoder.tile_batch( + initial_state, multiplier=beam_width) + cell = attention_wrapper.AttentionWrapper( + cell=cell, + attention_mechanism=attention_mechanism, + attention_layer_size=attention_depth, + alignment_history=with_alignment_history) + cell_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width) + if has_attention: + cell_state = cell_state.clone(cell_state=initial_state) + bsd = beam_search_decoder.BeamSearchDecoderV2( + cell=cell, + beam_width=beam_width, + output_layer=output_layer, + length_penalty_weight=0.0, + coverage_penalty_weight=coverage_penalty_weight, + output_time_major=time_major, + maximum_iterations=max_out) + + final_outputs, final_state, final_sequence_lengths = bsd( + embedding, + start_tokens=array_ops.fill([batch_size_tensor], start_token), + end_token=end_token, + initial_state=cell_state) + + def _t(shape): + if time_major: + return (shape[1], shape[0]) + shape[2:] + return shape + + self.assertIsInstance( + final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput) + self.assertIsInstance( + final_state, beam_search_decoder.BeamSearchDecoderState) + + beam_search_decoder_output = final_outputs.beam_search_decoder_output + expected_seq_length = 3 if context.executing_eagerly() else None + self.assertEqual( + _t((batch_size, expected_seq_length, beam_width)), + tuple(beam_search_decoder_output.scores.get_shape().as_list())) + self.assertEqual( + _t((batch_size, expected_seq_length, beam_width)), + tuple(final_outputs.predicted_ids.get_shape().as_list())) + + self.evaluate(variables.global_variables_initializer()) + eval_results = self.evaluate({ + 'final_outputs': final_outputs, + 'final_sequence_lengths': final_sequence_lengths + }) + + max_sequence_length = np.max(eval_results['final_sequence_lengths']) + + # A smoke test + self.assertEqual( + _t((batch_size, max_sequence_length, beam_width)), + eval_results['final_outputs'].beam_search_decoder_output.scores.shape) + self.assertEqual( + _t((batch_size, max_sequence_length, beam_width)), eval_results[ + 'final_outputs'].beam_search_decoder_output.predicted_ids.shape) + + def testDynamicDecodeRNNBatchMajorNoAttention(self): + self._testDynamicDecodeRNN(time_major=False, has_attention=False) + + def testDynamicDecodeRNNBatchMajorYesAttention(self): + self._testDynamicDecodeRNN(time_major=False, has_attention=True) + + def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self): + self._testDynamicDecodeRNN( + time_major=False, + has_attention=True, + with_alignment_history=True) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 5bcf0af8897ba8bc868951d03a18081e24a00f35..79c2ac2f500307ba23b6d97a7a30c6d04cea5176 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -25,6 +25,7 @@ import math import numpy as np from tensorflow.contrib.framework.python.framework import tensor_util +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -1919,7 +1920,15 @@ class AttentionWrapperState( def with_same_shape(old, new): """Check and set new tensor's shape.""" if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor): - return tensor_util.with_same_shape(old, new) + if not context.executing_eagerly(): + return tensor_util.with_same_shape(old, new) + else: + if old.shape.as_list() != new.shape.as_list(): + raise ValueError("The shape of the AttentionWrapperState is " + "expected to be same as the one to clone. " + "self.shape: %s, input.shape: %s" % + (old.shape, new.shape)) + return new return new return nest.map_structure( @@ -2048,13 +2057,13 @@ def _compute_attention(attention_mechanism, cell_output, attention_state, # the batched matmul is over memory_time, so the output shape is # [batch_size, 1, memory_size]. # we then squeeze out the singleton dim. - context = math_ops.matmul(expanded_alignments, attention_mechanism.values) - context = array_ops.squeeze(context, [1]) + context_ = math_ops.matmul(expanded_alignments, attention_mechanism.values) + context_ = array_ops.squeeze(context_, [1]) if attention_layer is not None: - attention = attention_layer(array_ops.concat([cell_output, context], 1)) + attention = attention_layer(array_ops.concat([cell_output, context_], 1)) else: - attention = context + attention = context_ return attention, alignments, next_attention_state diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 8f8f057702951094758b277ce060955f3dc6e99d..1d773a449890cd7335b2225db39d79ca958a3276 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -24,11 +24,12 @@ import numpy as np from tensorflow.contrib.seq2seq.python.ops import attention_wrapper from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.layers import base as layers_base +from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops @@ -182,11 +183,12 @@ def gather_tree_from_array(t, parent_ids, sequence_length): return ordered -def _check_maybe(t): +def _check_ndims(t): if t.shape.ndims is None: raise ValueError( "Expected tensor (%s) to have known rank, but ndims == None." % t) + def _check_static_batch_beam_maybe(shape, batch_size, beam_width): """Raises an exception if dimensions are known statically and can not be reshaped to [batch_size, beam_size, -1]. @@ -205,6 +207,7 @@ def _check_static_batch_beam_maybe(shape, batch_size, beam_width): return False return True + def _check_batch_beam(t, batch_size, beam_width): """Returns an Assert operation checking that the elements of the stacked TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point, @@ -229,70 +232,30 @@ def _check_batch_beam(t, batch_size, beam_width): return control_flow_ops.Assert(condition, [error_message]) +class BeamSearchDecoderMixin(object): + """BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder. -class BeamSearchDecoder(decoder.Decoder): - """BeamSearch sampling decoder. - - **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in - `AttentionWrapper`, then you must ensure that: - - - The encoder output has been tiled to `beam_width` via - `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). - - The `batch_size` argument passed to the `zero_state` method of this - wrapper is equal to `true_batch_size * beam_width`. - - The initial state created with `zero_state` above contains a - `cell_state` value containing properly tiled final state from the - encoder. - - An example: - - ``` - tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( - encoder_outputs, multiplier=beam_width) - tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( - encoder_final_state, multiplier=beam_width) - tiled_sequence_length = tf.contrib.seq2seq.tile_batch( - sequence_length, multiplier=beam_width) - attention_mechanism = MyFavoriteAttentionMechanism( - num_units=attention_depth, - memory=tiled_inputs, - memory_sequence_length=tiled_sequence_length) - attention_cell = AttentionWrapper(cell, attention_mechanism, ...) - decoder_initial_state = attention_cell.zero_state( - dtype, batch_size=true_batch_size * beam_width) - decoder_initial_state = decoder_initial_state.clone( - cell_state=tiled_encoder_final_state) - ``` - - Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use - when computing scores(https://arxiv.org/pdf/1609.08144.pdf). It encourages - the translation to cover all inputs. + It is expected to be used a base class for concrete BeamSearchDecoder. Since + this is a mixin class, it is expected to be used together with other class as + base. """ def __init__(self, cell, - embedding, - start_tokens, - end_token, - initial_state, beam_width, output_layer=None, length_penalty_weight=0.0, coverage_penalty_weight=0.0, - reorder_tensor_arrays=True): - """Initialize the BeamSearchDecoder. + reorder_tensor_arrays=True, + **kwargs): + """Initialize the BeamSearchDecoderMixin. Args: cell: An `RNNCell` instance. - embedding: A callable that takes a vector tensor of `ids` (argmax ids), - or the `params` argument for `embedding_lookup`. - start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. - end_token: `int32` scalar, the token that marks end of decoding. - initial_state: A (possibly nested tuple of...) tensors and TensorArrays. beam_width: Python integer, the number of beams. - output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., - `tf.layers.Dense`. Optional layer to apply to the RNN output prior - to storing the result or sampling. + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. coverage_penalty_weight: Float weight to penalize the coverage of source sentence. Disabled with 0.0. @@ -302,59 +265,35 @@ class BeamSearchDecoder(decoder.Decoder): Otherwise, the `TensorArray` will be returned as is. Set this flag to `False` if the cell state contains `TensorArray`s that are not amenable to reordering. + **kwargs: Dict, other keyword arguments for parent class. Raises: TypeError: if `cell` is not an instance of `RNNCell`, - or `output_layer` is not an instance of `tf.layers.Layer`. - ValueError: If `start_tokens` is not a vector or - `end_token` is not a scalar. + or `output_layer` is not an instance of `tf.keras.layers.Layer`. """ rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access if (output_layer is not None and - not isinstance(output_layer, layers_base.Layer)): + not isinstance(output_layer, layers.Layer)): raise TypeError( "output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell self._output_layer = output_layer self._reorder_tensor_arrays = reorder_tensor_arrays - if callable(embedding): - self._embedding_fn = embedding - else: - self._embedding_fn = ( - lambda ids: embedding_ops.embedding_lookup(embedding, ids)) - - self._start_tokens = ops.convert_to_tensor( - start_tokens, dtype=dtypes.int32, name="start_tokens") - if self._start_tokens.get_shape().ndims != 1: - raise ValueError("start_tokens must be a vector") - self._end_token = ops.convert_to_tensor( - end_token, dtype=dtypes.int32, name="end_token") - if self._end_token.get_shape().ndims != 0: - raise ValueError("end_token must be a scalar") - - self._batch_size = array_ops.size(start_tokens) + self._start_tokens = None + self._end_token = None + self._batch_size = None self._beam_width = beam_width self._length_penalty_weight = length_penalty_weight self._coverage_penalty_weight = coverage_penalty_weight - self._initial_cell_state = nest.map_structure( - self._maybe_split_batch_beams, initial_state, self._cell.state_size) - self._start_tokens = array_ops.tile( - array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) - self._start_inputs = self._embedding_fn(self._start_tokens) - - self._finished = array_ops.one_hot( - array_ops.zeros([self._batch_size], dtype=dtypes.int32), - depth=self._beam_width, - on_value=False, - off_value=True, - dtype=dtypes.bool) + super(BeamSearchDecoderMixin, self).__init__(**kwargs) @property def batch_size(self): return self._batch_size def _rnn_output_size(self): + """Get the output shape from the RNN layer.""" size = self._cell.output_size if self._output_layer is None: return size @@ -393,50 +332,6 @@ class BeamSearchDecoder(decoder.Decoder): predicted_ids=tensor_shape.TensorShape([self._beam_width]), parent_ids=tensor_shape.TensorShape([self._beam_width])) - @property - def output_dtype(self): - # Assume the dtype of the cell is the output_size structure - # containing the input_state's first component's dtype. - # Return that structure and int32 (the id) - dtype = nest.flatten(self._initial_cell_state)[0].dtype - return BeamSearchDecoderOutput( - scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), - predicted_ids=dtypes.int32, - parent_ids=dtypes.int32) - - def initialize(self, name=None): - """Initialize the decoder. - - Args: - name: Name scope for any created operations. - - Returns: - `(finished, start_inputs, initial_state)`. - """ - finished, start_inputs = self._finished, self._start_inputs - - dtype = nest.flatten(self._initial_cell_state)[0].dtype - log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) - array_ops.zeros([self._batch_size], dtype=dtypes.int32), - depth=self._beam_width, - on_value=ops.convert_to_tensor(0.0, dtype=dtype), - off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), - dtype=dtype) - init_attention_probs = get_attention_probs( - self._initial_cell_state, self._coverage_penalty_weight) - if init_attention_probs is None: - init_attention_probs = () - - initial_state = BeamSearchDecoderState( - cell_state=self._initial_cell_state, - log_probs=log_probs, - finished=finished, - lengths=array_ops.zeros( - [self._batch_size, self._beam_width], dtype=dtypes.int64), - accumulated_attention_probs=init_attention_probs) - - return (finished, start_inputs, initial_state) - def finalize(self, outputs, final_state, sequence_lengths): """Finalize and return the predicted_ids. @@ -562,7 +457,7 @@ class BeamSearchDecoder(decoder.Decoder): """ if isinstance(t, tensor_array_ops.TensorArray): return t - _check_maybe(t) + _check_ndims(t) if t.shape.ndims >= 1: return self._split_batch_beams(t, s) else: @@ -586,7 +481,7 @@ class BeamSearchDecoder(decoder.Decoder): """ if isinstance(t, tensor_array_ops.TensorArray): return t - _check_maybe(t) + _check_ndims(t) if t.shape.ndims >= 2: return self._merge_batch_beams(t, s) else: @@ -609,11 +504,18 @@ class BeamSearchDecoder(decoder.Decoder): if not isinstance(t, tensor_array_ops.TensorArray): return t # pylint: disable=protected-access - if (not t._infer_shape or not t._element_shape - or t._element_shape[0].ndims is None - or t._element_shape[0].ndims < 1): + # This is a bad hack due to the implementation detail of eager/graph TA. + # TODO(b/124374427): Update this to use public property of TensorArray. + if context.executing_eagerly(): + element_shape = t._element_shape + else: + element_shape = t._element_shape[0] + if (not t._infer_shape + or not t._element_shape + or element_shape.ndims is None + or element_shape.ndims < 1): shape = ( - t._element_shape[0] if t._infer_shape and t._element_shape + element_shape if t._infer_shape and t._element_shape else tensor_shape.TensorShape(None)) tf_logging.warn("The TensorArray %s in the cell state is not amenable to " "sorting based on the beam search result. For a " @@ -621,10 +523,10 @@ class BeamSearchDecoder(decoder.Decoder): "defined and have at least a rank of 1, but saw shape: %s" % (t.handle.name, shape)) return t - shape = t._element_shape[0] # pylint: enable=protected-access if not _check_static_batch_beam_maybe( - shape, tensor_util.constant_value(self._batch_size), self._beam_width): + element_shape, tensor_util.constant_value(self._batch_size), + self._beam_width): return t t = t.stack() with ops.control_dependencies( @@ -684,6 +586,359 @@ class BeamSearchDecoder(decoder.Decoder): return (beam_search_output, beam_search_state, next_inputs, finished) +class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.Decoder): + # Note that the inheritance hierarchy is important here. The Mixin has to be + # the first parent class since we will use super().__init__(), and Mixin which + # is a object will properly invoke the __init__ method of other parent class. + """BeamSearch sampling decoder. + + **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in + `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). + - The `batch_size` argument passed to the `zero_state` method of this + wrapper is equal to `true_batch_size * beam_width`. + - The initial state created with `zero_state` above contains a + `cell_state` value containing properly tiled final state from the + encoder. + + An example: + + ``` + tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + encoder_outputs, multiplier=beam_width) + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( + encoder_final_state, multiplier=beam_width) + tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + sequence_length, multiplier=beam_width) + attention_mechanism = MyFavoriteAttentionMechanism( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + attention_cell = AttentionWrapper(cell, attention_mechanism, ...) + decoder_initial_state = attention_cell.zero_state( + dtype, batch_size=true_batch_size * beam_width) + decoder_initial_state = decoder_initial_state.clone( + cell_state=tiled_encoder_final_state) + ``` + + Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use + when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages + the decoder to cover all inputs. + """ + + def __init__(self, + cell, + embedding, + start_tokens, + end_token, + initial_state, + beam_width, + output_layer=None, + length_penalty_weight=0.0, + coverage_penalty_weight=0.0, + reorder_tensor_arrays=True): + """Initialize the BeamSearchDecoder. + + Args: + cell: An `RNNCell` instance. + embedding: A callable that takes a vector tensor of `ids` (argmax ids), + or the `params` argument for `embedding_lookup`. + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + beam_width: Python integer, the number of beams. + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. + length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. Disabled with 0.0. + reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell + state will be reordered according to the beam search path. If the + `TensorArray` can be reordered, the stacked form will be returned. + Otherwise, the `TensorArray` will be returned as is. Set this flag to + `False` if the cell state contains `TensorArray`s that are not amenable + to reordering. + + Raises: + TypeError: if `cell` is not an instance of `RNNCell`, + or `output_layer` is not an instance of `tf.keras.layers.Layer`. + ValueError: If `start_tokens` is not a vector or + `end_token` is not a scalar. + """ + super(BeamSearchDecoder, self).__init__( + cell, + beam_width, + output_layer=output_layer, + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight, + reorder_tensor_arrays=reorder_tensor_arrays) + + if callable(embedding): + self._embedding_fn = embedding + else: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + + self._batch_size = array_ops.size(start_tokens) + self._initial_cell_state = nest.map_structure( + self._maybe_split_batch_beams, initial_state, self._cell.state_size) + self._start_tokens = array_ops.tile( + array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) + self._start_inputs = self._embedding_fn(self._start_tokens) + + self._finished = array_ops.one_hot( + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) + + def initialize(self, name=None): + """Initialize the decoder. + + Args: + name: Name scope for any created operations. + + Returns: + `(finished, start_inputs, initial_state)`. + """ + finished, start_inputs = self._finished, self._start_inputs + + dtype = nest.flatten(self._initial_cell_state)[0].dtype + log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) + init_attention_probs = get_attention_probs( + self._initial_cell_state, self._coverage_penalty_weight) + if init_attention_probs is None: + init_attention_probs = () + + initial_state = BeamSearchDecoderState( + cell_state=self._initial_cell_state, + log_probs=log_probs, + finished=finished, + lengths=array_ops.zeros( + [self._batch_size, self._beam_width], dtype=dtypes.int64), + accumulated_attention_probs=init_attention_probs) + + return (finished, start_inputs, initial_state) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_cell_state)[0].dtype + return BeamSearchDecoderOutput( + scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), + predicted_ids=dtypes.int32, + parent_ids=dtypes.int32) + + +class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder): + # Note that the inheritance hierarchy is important here. The Mixin has to be + # the first parent class since we will use super().__init__(), and Mixin which + # is a object will properly invoke the __init__ method of other parent class. + """BeamSearch sampling decoder. + + **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in + `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). + - The `batch_size` argument passed to the `zero_state` method of this + wrapper is equal to `true_batch_size * beam_width`. + - The initial state created with `zero_state` above contains a + `cell_state` value containing properly tiled final state from the + encoder. + + An example: + + ``` + tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + encoder_outputs, multiplier=beam_width) + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( + encoder_final_state, multiplier=beam_width) + tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + sequence_length, multiplier=beam_width) + attention_mechanism = MyFavoriteAttentionMechanism( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + attention_cell = AttentionWrapper(cell, attention_mechanism, ...) + decoder_initial_state = attention_cell.zero_state( + dtype, batch_size=true_batch_size * beam_width) + decoder_initial_state = decoder_initial_state.clone( + cell_state=tiled_encoder_final_state) + ``` + + Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use + when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages + the decoding to cover all inputs. + """ + + def __init__(self, + cell, + beam_width, + embedding_fn=None, + output_layer=None, + length_penalty_weight=0.0, + coverage_penalty_weight=0.0, + reorder_tensor_arrays=True, + **kwargs): + """Initialize the BeamSearchDecoderV2. + + Args: + cell: An `RNNCell` instance. + beam_width: Python integer, the number of beams. + embedding_fn: A callable that takes a vector tensor of `ids` (argmax ids). + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. + length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. Disabled with 0.0. + reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell + state will be reordered according to the beam search path. If the + `TensorArray` can be reordered, the stacked form will be returned. + Otherwise, the `TensorArray` will be returned as is. Set this flag to + `False` if the cell state contains `TensorArray`s that are not amenable + to reordering. + **kwargs: Dict, other keyword arguments for initialization. + + Raises: + TypeError: if `cell` is not an instance of `RNNCell`, + or `output_layer` is not an instance of `tf.keras.layers.Layer`. + """ + super(BeamSearchDecoderV2, self).__init__( + cell, + beam_width, + output_layer=output_layer, + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight, + reorder_tensor_arrays=reorder_tensor_arrays, + **kwargs) + + if embedding_fn is None or callable(embedding_fn): + self._embedding_fn = embedding_fn + else: + raise ValueError("embedding_fn is expected to be a callable, got %s" % + type(embedding_fn)) + + def initialize(self, + embedding, + start_tokens, + end_token, + initial_state): + """Initialize the decoder. + + Args: + embedding: A tensor from the embedding layer output, which is the + `params` argument for `embedding_lookup`. + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + Returns: + `(finished, start_inputs, initial_state)`. + Raises: + ValueError: If `start_tokens` is not a vector or `end_token` is not a + scalar. + """ + if embedding is not None and self._embedding_fn is not None: + raise ValueError( + "embedding and embedding_fn cannot be provided at same time") + elif embedding is not None: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + + self._batch_size = array_ops.size(start_tokens) + self._initial_cell_state = nest.map_structure( + self._maybe_split_batch_beams, initial_state, self._cell.state_size) + self._start_tokens = array_ops.tile( + array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) + self._start_inputs = self._embedding_fn(self._start_tokens) + + self._finished = array_ops.one_hot( + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) + + finished, start_inputs = self._finished, self._start_inputs + + dtype = nest.flatten(self._initial_cell_state)[0].dtype + log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) + init_attention_probs = get_attention_probs( + self._initial_cell_state, self._coverage_penalty_weight) + if init_attention_probs is None: + init_attention_probs = () + + initial_state = BeamSearchDecoderState( + cell_state=self._initial_cell_state, + log_probs=log_probs, + finished=finished, + lengths=array_ops.zeros( + [self._batch_size, self._beam_width], dtype=dtypes.int64), + accumulated_attention_probs=init_attention_probs) + + return (finished, start_inputs, initial_state) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_cell_state)[0].dtype + return BeamSearchDecoderOutput( + scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), + predicted_ids=dtypes.int32, + parent_ids=dtypes.int32) + + def call(self, embeddning, start_tokens, end_token, initial_state, **kwargs): + init_kwargs = kwargs + init_kwargs["start_tokens"] = start_tokens + init_kwargs["end_token"] = end_token + init_kwargs["initial_state"] = initial_state + return decoder.dynamic_decode(self, + output_time_major=self.output_time_major, + impute_finished=self.impute_finished, + maximum_iterations=self.maximum_iterations, + parallel_iterations=self.parallel_iterations, + swap_memory=self.swap_memory, + decoder_init_input=embeddning, + decoder_init_kwargs=init_kwargs) + + def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, beam_width, end_token, length_penalty_weight, coverage_penalty_weight): @@ -1068,7 +1323,7 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, """ if isinstance(gather_from, tensor_array_ops.TensorArray): return gather_from - _check_maybe(gather_from) + _check_ndims(gather_from) if gather_from.shape.ndims >= len(gather_shape): return _tensor_gather_helper( gather_indices=gather_indices, diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py index 290c16fe3966791ea78986539750caf938a37322..40bf7081a3f22dfd68fd46f0f61695ee9ca7863b 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py @@ -35,7 +35,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking _model_ops = loader.load_op_library( diff --git a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py index 9184198cd4c8fd2a7609714d094d5ef2b6868658..80afcfb251f4d6455a9eb8ba5df4a6e43d2feb1c 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py @@ -32,7 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking _stats_ops = loader.load_op_library( diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index d1be31ddc799ce4c4ef9baa15729fde7925f2f6c..4ba814b9e3d3621f9ab924961e2740885fa93b33 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -161,7 +161,10 @@ py_test( ], shard_count = 10, srcs_version = "PY2AND3", - tags = ["no_pip_gpu"], # b/63391119 + tags = [ + "no_pip_gpu", # b/63391119 + "notap", # b/124520733 + ], deps = [ ":estimators", ":feature_keys", diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 294dbddcb5ee1b1758182c10e2816f353d989084..9665604a52ba5da427b3a27415e58b4d6c9b93a1 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -23,17 +23,13 @@ package( ], ) -cc_library( - name = "all_ops", +py_library( + name = "tpu_py", + srcs = ["python/ops/tpu_ops.py"], + srcs_version = "PY2AND3", deps = [ - ":cross_replica_ops_op_lib", - ":heartbeat_ops_op_lib", - ":host_compute_ops_op_lib", - ":infeed_ops_op_lib", - ":outfeed_ops_op_lib", - ":replication_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_embedding_ops_op_lib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:tpu_ops_gen", ], ) @@ -75,7 +71,6 @@ py_library( ":functional", ":tpu_embedding", ":tpu_lib", - ":tpu_ordinal_selector_py", "//tensorflow/contrib/training:training_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -98,122 +93,15 @@ py_library( ], ) -tf_gen_op_libs( - op_lib_names = [ - "cross_replica_ops", - "heartbeat_ops", - "host_compute_ops", - "infeed_ops", - "outfeed_ops", - "replication_ops", - "tpu_configuration_ops", - "tpu_embedding_ops", - "tpu_ordinal_selector_op", - "functional_ops", - ], - deps = [ - "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils", - "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils", - "//tensorflow/core:lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", - ], -) - -tf_custom_op_library( - name = "python/ops/_tpu_ops.so", - srcs = [ - "ops/cross_replica_ops.cc", - "ops/heartbeat_ops.cc", - "ops/host_compute_ops.cc", - "ops/infeed_ops.cc", - "ops/outfeed_ops.cc", - "ops/replication_ops.cc", - "ops/tpu_configuration_ops.cc", - "ops/tpu_embedding_ops.cc", - ], - deps = [ - "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils", - "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", - ], -) - -tf_gen_op_wrapper_py( - name = "tpu_ops", - hidden = [ - "SendTPUEmbeddingGradients", - "EnqueueTPUEmbeddingIntegerBatch", - "EnqueueTPUEmbeddingSparseBatch", - "EnqueueTPUEmbeddingSparseTensorBatch", - ], - deps = [ - ":cross_replica_ops_op_lib", - ":heartbeat_ops_op_lib", - ":host_compute_ops_op_lib", - ":infeed_ops_op_lib", - ":outfeed_ops_op_lib", - ":replication_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_embedding_ops_op_lib", - ], -) - -tf_custom_op_library( - name = "python/ops/_tpu_ordinal_selector_op.so", - srcs = ["ops/tpu_ordinal_selector_op.cc"], -) - -tf_custom_op_py_library( - name = "tpu_ordinal_selector_py", - srcs = ["python/ops/tpu_ordinal_selector_op.py"], - dso = [":python/ops/_tpu_ordinal_selector_op.so"], - kernels = [ - ":tpu_ordinal_selector_op_op_lib", - ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":tpu_ordinal_selector_op", - ], -) - -tf_gen_op_wrapper_py( - name = "tpu_ordinal_selector_op", - deps = [ - ":tpu_ordinal_selector_op_op_lib", - ], -) - -tf_custom_op_library( - name = "python/ops/_functional_ops.so", - srcs = ["ops/functional_ops.cc"], -) - -tf_gen_op_wrapper_py( - name = "gen_functional_ops", - out = "python/tpu/gen_functional_ops.py", - hidden = [ - "TPUPartitionedCall", - ], - deps = [":functional_ops_op_lib"], -) - -tf_custom_op_py_library( +py_library( name = "functional", srcs = ["python/tpu/functional.py"], - dso = [":python/ops/_functional_ops.so"], - kernels = [ - ":functional_ops_op_lib", - ], srcs_version = "PY2AND3", visibility = [ "//visibility:public", ], deps = [ - ":gen_functional_ops", + "//tensorflow/python:tpu_ops_gen", ], ) @@ -223,28 +111,8 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_pb2_grpc", - "//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_proto_py", "//tensorflow/contrib/tpu/profiler:trace_events_proto_py", - "//tensorflow/python:util", - ], -) - -tf_custom_op_py_library( - name = "tpu_py", - srcs = ["python/ops/tpu_ops.py"], - dso = [":python/ops/_tpu_ops.so"], - kernels = [ - ":all_ops", - ], - srcs_version = "PY2AND3", - deps = [ - ":profiler", - ":tpu_ops", - "//tensorflow/contrib/compiler:xla", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform", + "//tensorflow/core/profiler:profiler_analysis_proto_py", "//tensorflow/python:util", ], ) @@ -327,7 +195,6 @@ py_library( ":datasets", ":functional", ":profiler", - ":tpu_ordinal_selector_py", ":tpu_py", "//tensorflow/compiler/xla/experimental/xla_sharding", "//tensorflow/compiler/xla/python_api:xla_shape", @@ -347,6 +214,7 @@ py_library( "//tensorflow/python:framework", "//tensorflow/python:framework_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python:tpu_ops_gen", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", @@ -466,17 +334,20 @@ tf_py_test( py_library( name = "tpu_embedding", - srcs = ["python/tpu/tpu_embedding.py"], + srcs = [ + "python/tpu/tpu_embedding.py", + "python/tpu/tpu_embedding_gradient.py", + ], srcs_version = "PY2AND3", deps = [ ":tpu_lib", - ":tpu_ops", "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", "//tensorflow/python:partitioned_variables", + "//tensorflow/python:tpu_ops_gen", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "@six_archive//:six", diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 7ad30c61e42cefa70d660d265ccc117c6ff00d87..2a8aeea317478d85cb9c236848eb66a5d73781bf 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -4,17 +4,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_profiler_all_protos") - -tf_proto_library( - name = "tpu_profiler_proto", - srcs = ["tpu_profiler.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - protodeps = tf_profiler_all_protos() + tf_additional_all_protos(), - visibility = ["//visibility:public"], -) cc_library( name = "dump_tpu_profile", @@ -22,10 +11,10 @@ cc_library( hdrs = ["dump_tpu_profile.h"], visibility = ["//visibility:public"], deps = [ - ":tpu_profiler_proto_cc", ":trace_events_proto_cc", ":trace_events_to_json", "//tensorflow/core:framework", + "//tensorflow/core:grpc_services", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler:protos_all_cc", @@ -82,20 +71,10 @@ tf_cc_test( ], ) -tf_proto_library( - name = "tpu_profiler_analysis_proto", - srcs = ["tpu_profiler_analysis.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - protodeps = [":tpu_profiler_proto"] + tf_additional_all_protos(), - visibility = ["//visibility:public"], -) - py_library( name = "tpu_profiler_analysis_pb2_grpc", srcs = ["tpu_profiler_analysis_pb2_grpc.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [":tpu_profiler_analysis_proto_py"], + deps = ["//tensorflow/core/profiler:profiler_analysis_proto_py"], ) diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h index ecf21b1de2219e8896d5e8b79325a193de0b0fa1..7ddd7b1c9be945ba45b945f7b822d90d5a3b4cbc 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_ #define TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_ -#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/grpc_services.h" namespace tensorflow { namespace tpu { diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 55f7c6bcbc11b3a11bb3372aa4f26d3c8a87ff3c..ec0d5fec44e1687c20946c700769efe5b818af68 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -24,20 +24,15 @@ import platform from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging if platform.system() != "Windows": # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tpu.ops import gen_tpu_ops - from tensorflow.contrib.tpu.ops.gen_tpu_ops import * - - from tensorflow.contrib.util import loader - from tensorflow.python.platform import resource_loader + from tensorflow.python.ops import gen_tpu_ops + from tensorflow.python.ops.gen_tpu_ops import * # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - _tpu_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("_tpu_ops.so")) - def _create_default_group_assignment(): num_shards = tpu_function.get_tpu_context().number_of_shards if num_shards is None: @@ -160,6 +155,36 @@ if platform.system() != "Windows": dtypes.complex64, dtypes.uint32 ]) + @ops.RegisterGradient("TPUEmbeddingActivations") + def _embedding_activations_grad(activations_op, grad_wrt_activations): + """Saves the gradient of embedding activations ops in a graph collection.""" + g = ops.get_default_graph() + table_id = activations_op.get_attr("table_id") + lookup_id = activations_op.get_attr("lookup_id") + table_gradients = g.get_collection_ref( + "tpu_embedding_gradients_table_%d" % table_id) + + if not table_gradients: + raise RuntimeError( + "Gradients for TPUEmbedding have been generated in non-training mode." + "This is not expected. Consider putting your Optimizer.minimize code " + "behind the training mode condition check. For Estimator, you can " + "do \n\n" + " if mode == tf.estimator.ModeKeys.TRAIN:\n" + " train_op = opt.minimize(loss)\n" + "\n") + + table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations) + return [ + # RegisterGradient requires that value be returned for all inputs. Since + # the first argument (tpu_gradient_variable_{table_name}) has shape [1], + # we will return zeros(shape=[1]). The actual gradient w.r.t. the + # embedding activations (grad_wrt_activations) has the same shape as the + # activations returned by embedding_activations. + array_ops.zeros(arg.shape, dtype=dtypes.float32) + for arg in activations_op.inputs + ] + def infeed_dequeue(dtype, shape, name=None): """A placeholder op for a value that will be fed into the computation. @@ -237,12 +262,11 @@ if platform.system() != "Windows": """ if learning_rates is None: learning_rates = [] - return gen_tpu_ops._send_tpu_embedding_gradients( + return gen_tpu_ops.send_tpu_embedding_gradients( inputs=inputs, learning_rates=learning_rates, config=config, name=name) - send_tpu_embedding_gradients.__doc__ = ( - gen_tpu_ops._send_tpu_embedding_gradients.__doc__) + gen_tpu_ops.send_tpu_embedding_gradients.__doc__) # pylint: disable=protected-access def enqueue_tpu_embedding_integer_batch(batch, @@ -268,14 +292,14 @@ if platform.system() != "Windows": """ if mode_override is None: mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_integer_batch( + return gen_tpu_ops.enqueue_tpu_embedding_integer_batch( batch=batch, device_ordinal=device_ordinal, mode_override=mode_override, name=name) enqueue_tpu_embedding_integer_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_integer_batch.__doc__) + gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__) # pylint: disable=protected-access def enqueue_tpu_embedding_sparse_batch(sample_indices, @@ -317,7 +341,7 @@ if platform.system() != "Windows": """ if mode_override is None: mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_sparse_batch( + return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch( sample_indices=sample_indices, embedding_indices=embedding_indices, aggregation_weights=aggregation_weights, @@ -327,7 +351,7 @@ if platform.system() != "Windows": name=name) enqueue_tpu_embedding_sparse_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_sparse_batch.__doc__) + gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__) # pylint: disable=protected-access def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices, @@ -375,7 +399,7 @@ if platform.system() != "Windows": """ if mode_override is None: mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch( + return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( sample_indices=sample_indices, embedding_indices=embedding_indices, aggregation_weights=aggregation_weights, @@ -386,7 +410,7 @@ if platform.system() != "Windows": name=name) enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch.__doc__) + gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__) else: # We have already built the appropriate libraries into the binary via CMake diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py index 5ca38cd1bae5753a7398834bd96d3b26e66b4941..6917ac2e1a769378c77dcdcd0d63da2028a3a34c 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py @@ -23,15 +23,12 @@ import platform if platform.system() != "Windows": # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tpu.ops.gen_tpu_ordinal_selector_op import * + from tensorflow.python.ops.gen_tpu_ops import tpu_ordinal_selector from tensorflow.contrib.util import loader from tensorflow.python.platform import resource_loader # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - _tpu_ordinal_selector_op = loader.load_op_library( - resource_loader.get_path_to_datafile("_tpu_ordinal_selector_op.so")) - else: # We have already built the appropriate libraries into the binary via CMake # if we have built contrib, so we don't need this diff --git a/tensorflow/contrib/tpu/python/profiler/__init__.py b/tensorflow/contrib/tpu/python/profiler/__init__.py index 15ce6aceec299adacd7025f0021cf8b6f6ef765b..7e64448348462ad1001d5d8826c8c7b3c6e636e8 100644 --- a/tensorflow/contrib/tpu/python/profiler/__init__.py +++ b/tensorflow/contrib/tpu/python/profiler/__init__.py @@ -20,8 +20,8 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.contrib.tpu.profiler.tpu_profiler_analysis_pb2 import * from tensorflow.contrib.tpu.profiler.trace_events_pb2 import * +from tensorflow.core.profiler.profiler_analysis_pb2 import * # pylint: enable=wildcard-import,unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py index dd239d5d78fbdc012566398b3a5bec89eeaf4ed2..98aa7827fcf38b10e97318067ffa99008e93c557 100644 --- a/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py @@ -286,6 +286,7 @@ class EmbeddingConfig(object): self._optimization_parameters = _get_tpu_embedding_optimization_parameters( self._embedding_config_spec) self._mode_to_tpu_embedding_dict = {} + self.dummy_table_variables = None def has_embedding_tables(self): return bool(self._table_to_config_dict) diff --git a/tensorflow/contrib/tpu/python/tpu/functional.py b/tensorflow/contrib/tpu/python/tpu/functional.py index 24c85156e53a9b770f811c4cf3b903eab6553c76..3d04c64033b5a27b34b5aa77a8753246d35d23aa 100644 --- a/tensorflow/contrib/tpu/python/tpu/functional.py +++ b/tensorflow/contrib/tpu/python/tpu/functional.py @@ -18,22 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import platform +from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import gen_functional_ops - - -TPUPartitionedCall = gen_functional_ops._tpu_partitioned_call # pylint: disable=invalid-name,protected-access - - -if platform.system() != "Windows": - # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tpu.ops.gen_tpu_ordinal_selector_op import * - - from tensorflow.contrib.util import loader - from tensorflow.python.platform import resource_loader - # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - - _tpu_partitioned_call_op = loader.load_op_library( - resource_loader.get_path_to_datafile("../ops/_functional_ops.so") - ) +TPUPartitionedCall = tpu_ops.tpu_partitioned_call # pylint: disable=invalid-name diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py index 43b9168eccbec4cd8ce874beff7b0f1d8e09e812..ae0582208450919b79a7c3031c726e24986aa456 100644 --- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -237,7 +237,8 @@ class TensorTracer(object): (2) which Ops to be traced (via op.name or op.type) (3) output trace file path. """ - + # The set of graphs that are rewritten by tensor tracer. + _traced_graphs = set() @staticmethod def _match_next_flag(flags, pos): """Returns the match for the next TensorTracer flag. @@ -1559,6 +1560,12 @@ class TensorTracer(object): RuntimeError: If tensor_fetches is None or empty. """ + if graph in TensorTracer._traced_graphs: + logging.warning('Graph is already rewritten with tensor tracer, ignoring ' + 'multiple calls.') + return tensor_fetches + else: + TensorTracer._traced_graphs.add(graph) self._device_type = _DEVICE_TYPE_TPU self._num_replicas = num_replicas self._num_replicas_per_host = num_replicas_per_host @@ -1604,6 +1611,14 @@ class TensorTracer(object): Raises: RuntimeError: If tensor_fetches is None or empty. """ + + if graph in TensorTracer._traced_graphs: + logging.warning('Graph is already rewritten with tensor tracer, ignoring ' + 'multiple calls.') + return tensor_fetches + else: + TensorTracer._traced_graphs.add(graph) + self._device_type = _DEVICE_TYPE_CPU self._num_replicas = 1 self._num_replicas_per_host = 1 diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 673129b4bef8a7470192a5d7650a858257f653bb..3b2d0534773fa0cce3c515cfaa7102cec195fcc3 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -810,6 +810,9 @@ def split_compile_and_replicate(computation, serialized_padding_maps.append(padding_map.SerializeToString()) metadata_kwargs["padding_map"] = serialized_padding_maps + metadata_kwargs["step_marker_location"] = getattr( + computation, "step_marker_location", "STEP_MARK_AT_ENTRY") + graph = ops.get_default_graph() # Fan-in: Builds a TPUReplicatedInput node for each input. @@ -903,6 +906,17 @@ def split_compile_and_replicate(computation, else: output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) + # tensor_tracer imports tpu.py. Local import to tensor_tracer to avoid + # import-cycle + # pylint: disable=g-import-not-at-top + from tensorflow.contrib.tpu.python.tpu import tensor_tracer + # pylint: enable=g-import-not-at-top + if tensor_tracer.TensorTracer.is_enabled(): + tt = tensor_tracer.TensorTracer() + output_tensors = tt.trace_tpu(ops.get_default_graph(), + output_tensors, control_deps, + num_replicas) + context.ExitResult(output_tensors) finally: context.report_unsupported_operations() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py index eb99a18d83987b098dcae9d58d9af14deebc4f56..1ba8017cda834436cbcc72f03a1f8b88295bf80c 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py @@ -25,7 +25,6 @@ import re import six from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.tpu.ops import gen_tpu_ops from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.core.protobuf.tpu import optimization_parameters_pb2 @@ -40,7 +39,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables TRAINING = elc.TPUEmbeddingConfiguration.TRAINING INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE @@ -268,10 +266,11 @@ class TPUEmbedding(object): base_optimizer) train_op = cross_shard_optimizer.minimize(loss) - # `train_op` and `send_gradients_op` must happen in order. - with ops.control_dependencies([train_op]): - send_gradients_op = embedding.generate_send_gradients_op() - with ops.control_dependencies([send_gradients_op]): + gradients = ( + tpu_embedding_gradient.get_gradients_through_compute_gradients( + cross_shard_optimizer, loss, activations) + send_gradients_op = embedding.generate_send_gradients_op(gradients) + with ops.control_dependencies([train_op, send_gradients_op]): loss = array_ops.identity(loss) loss = tpu.shard(computation, @@ -281,7 +280,6 @@ class TPUEmbedding(object): sess.run(tpu.initialize_system(embedding_config= embedding.config_proto)) sess.run(variables.global_variables_initializer()) - sess.run(embedding.init_ops) sess.run(embedding_variables_and_ops.load_ops()) sess.run(enqueue_ops) loss_val = sess.run(loss) @@ -360,8 +358,6 @@ class TPUEmbedding(object): _validate_batch_size(self._batch_size, self._num_cores) self._batch_size_per_core = self._batch_size // self._num_cores - self._init_ops = [] - # TODO(shizhiw): remove `mode`? if mode == TRAINING: _validate_optimization_parameters(optimization_parameters) @@ -384,9 +380,6 @@ class TPUEmbedding(object): self._optimizer_handler = _get_optimization_handler( self._optimization_parameters) - dummy_table_variables_init_op = self._create_dummy_table_variables() - self._init_ops.append(dummy_table_variables_init_op) - self._config_proto = self._create_config_proto() @property @@ -441,19 +434,6 @@ class TPUEmbedding(object): """ return self._config_proto - @property - def init_ops(self): - """Initialization ops for TPU embedding. - - It must be called after all global variables have been initialized, - i.e. after `global_variables_initializer()`, as it loads embedding - tables into TPU. - - Returns: - A list of ops. - """ - return self._init_ops - @property def table_to_config_dict(self): return copy.copy(self._table_to_config_dict) @@ -462,6 +442,10 @@ class TPUEmbedding(object): def feature_to_table_dict(self): return copy.copy(self._feature_to_table_dict) + @property + def table_to_features_dict(self): + return copy.copy(self._table_to_features_dict) + @property def optimization_parameters(self): return self._optimization_parameters @@ -584,51 +568,6 @@ class TPUEmbedding(object): slot_variables_by_table, load_ops, retrieve_ops) - def _create_dummy_table_variables(self): - """Create dummy embedding table variables. - - The sole purpose of these dummy variables are to trigger gradient - calcuation wrt them so that the gradients wrt activation can be captured - and later sent to TPU embedding. - - Returns: - Initializer for these variables. - - Raises: - RuntimeError: if collection to store gradients already exists and is not - empty. - """ - self._dummy_table_variables = [] - # TODO(shizhiw): remove table id. - for table_id, table in enumerate(self._table_to_features_dict): - self._dummy_table_variables.append( - variable_scope.get_variable( - 'tpu_embedding_dummy_table_variable_%s' % table, - dtype=dtypes.float32, - shape=[1], - use_resource=True, - trainable=True, - # TODO(shizhiw): Remove these dummy variables as - # tensorflow optimizer creates slot variable for them which - # is undesirable. - # e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1}. - # Explicitly specifying collections prevents this variable from - # being added to the GLOBAL_VARIABLES collection, so that Saver() - # ignores it. - collections=['tpu_embedding_dummy_table_variables'])) - - g = ops.get_default_graph() - table_gradients = g.get_collection_ref( - 'tpu_embedding_gradients_table_%d' % table_id) - if table_gradients: - raise RuntimeError( - 'tpu_embedding_gradients_table_%d is not empty.' % table_id) - table_gradients.extend([None] * len(self._table_to_features_dict[table])) - - return variables.variables_initializer( - self._dummy_table_variables, - name='tpu_embedding_dummy_table_variables_init') - def generate_enqueue_ops(self, sparse_features_list): """Generate enqueue ops. @@ -775,52 +714,34 @@ class TPUEmbedding(object): for lookup_id, feature in enumerate(features): start_row = lookup_id * self._batch_size_per_core end_row = start_row + self._batch_size_per_core - activations[feature] = gen_tpu_ops.tpu_embedding_activations( - self._dummy_table_variables[table_id], - recv_activations[table_id][start_row:end_row, :], - table_id=table_id, - lookup_id=lookup_id) + activations[feature] = recv_activations[table_id][start_row:end_row, :] return activations - # TODO(shizhiw): Make `gradient_multiplier` per feature. Setting it to 0 would - # have the effect of `tf.stop_gradients()`. - # TODO(shizhiw): Consider alternative ways to capture gradients wrt embedding - # layer outputs to remove `_dummy_table_variables`, - # `_embedding_activation_grad` and `tpu_embedding_gradients_table_%d'. - def generate_send_gradients_op(self, gradient_multipliers=None): - """Retrieve gradients from collections and send them to TPU embedding. + def generate_send_gradients_op(self, feature_to_gradient_dict): + """Send gradient to TPU embedding. Args: - gradient_multipliers: None, or dict mapping table names to gradient - multiplier Tensors. + feature_to_gradient_dict: dict mapping feature names to gradient wrt + activations. Returns: SendTPUEmbeddingGradients Op. Raises: - ValueError: If required gradients have not been defined. RuntimeError: If `mode` is not `TRAINING`. """ if self._mode != TRAINING: raise RuntimeError('Only in training mode gradients need to ' 'be sent to TPU embedding; got mode {}.' .format(self._mode)) - - g = ops.get_default_graph() - gradients = list() - for table_id, table in enumerate(self._table_to_config_dict): - table_gradients = g.get_collection( - 'tpu_embedding_gradients_table_%d' % table_id) - if any(gradient is None for gradient in table_gradients): - raise ValueError( - 'Table {}/{} has undefined gradients: this is probably because the ' - 'model asked TPUEmbedding to compute activations that were not ' - 'used.'.format(table_id, table)) + gradients = [] + for table in self._table_to_features_dict: + features = self._table_to_features_dict[table] + table_gradients = [ + feature_to_gradient_dict[feature] for feature in features + ] concat_table_grads = array_ops.concat(table_gradients, axis=0) - if gradient_multipliers is not None: - concat_table_grads *= gradient_multipliers[table.name] gradients.append(concat_table_grads) - return tpu_ops.send_tpu_embedding_gradients( inputs=gradients, config=self.config_proto.SerializeToString()) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py new file mode 100644 index 0000000000000000000000000000000000000000..dace0d801b3a91caae9cafea59366f4adc9325a7 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py @@ -0,0 +1,153 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== +"""Optional helper for gradient handling.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables + + +def get_gradients_through_compute_gradients(optimizer, loss, activations): + """Compute gradients to send to TPU embedding. + + Args: + optimizer: a subclass of optimizer.Optimizer, usually CrossShardOptimizer. + Used to call compute_gradients(). + loss: a Tensor to call optimizer.compute_gradients() on. + activations: an OrderedDict mapping feature_name to Tensors of activations. + + Returns: + An OrderedDict mapping from feature name Strings to Tensors of gradients of + the loss wrt the activations of the features. + """ + activation_list = activations.values() + grads_and_vars = optimizer.compute_gradients(loss, activation_list) + grads = [grad for grad, _ in grads_and_vars] + feature_to_gradient_dict = collections.OrderedDict( + zip(activations.keys(), grads)) + return feature_to_gradient_dict + + +def create_dummy_table_variables(tpu_embedding): + """Create dummy embedding table variables. + + The sole purpose of these dummy variables are to trigger gradient + calcuation wrt them so that the gradients wrt activation can be captured + and later sent to TPU embedding. + + Args: + tpu_embedding: TPUEmbedding, dummy table variables will be created for use + with tpu_embedding. + + Returns: + A tuple of dummy variables and their initializer. + + Raises: + RuntimeError: if collection to store gradients already exists and is not + empty. + """ + dummy_table_variables = collections.OrderedDict() + for table_id, table in enumerate(tpu_embedding.table_to_features_dict): + dummy_table_variables[table] = ( + # Explicitly specifying collections prevents this variable from + # being added to the GLOBAL_VARIABLES collection, so that Saver() + # ignores it. + # But Tensorflow optimizer creates slot variable for these dummy + # variable, e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1}, + # which will be in GLOBAL_VARIABLES collection, + variable_scope.get_variable( + 'tpu_embedding_dummy_table_variable_{}'.format(table), + dtype=dtypes.float32, + shape=[1], + use_resource=True, + trainable=True, + collections=['tpu_embedding_dummy_table_variables'])) + + g = ops.get_default_graph() + table_gradients = g.get_collection_ref( + 'tpu_embedding_gradients_table_{}'.format(table_id)) + if table_gradients: + raise RuntimeError( + 'tpu_embedding_gradients_table_{} is not empty.'.format(table_id)) + table_gradients.extend( + [None] * len(tpu_embedding.table_to_features_dict[table])) + + return (dummy_table_variables, + variables.variables_initializer( + dummy_table_variables.values(), + name='tpu_embedding_dummy_table_variables_init')) + + +def hook_dummy_table_variables_to_activations(tpu_embedding, activations, + dummy_table_variables): + """Have activations depend on dummy table variables for gradient intercept. + + Args: + tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from + tpu_embedding. + activations: An OrderedDict of feature name String to activation tensors. + dummy_table_variables: An OrderedDict of table name String to dummy table + variables. + + Returns: + An OrderedDict of feature name String to activation tensors, which can be + used just as the activations input. + """ + new_activations = collections.OrderedDict() + for feature in activations: + table = tpu_embedding.feature_to_table_dict[feature] + new_activations[feature] = tpu_ops.tpu_embedding_activations( + dummy_table_variables[table], + activations[feature], + table_id=tpu_embedding.table_to_config_dict.keys().index(table), + lookup_id=tpu_embedding.table_to_features_dict[table].index(feature)) + return new_activations + + +def get_gradients_through_dummy_table_variables(tpu_embedding): + """Get gradients wrt the activations of each feature. + + Args: + tpu_embedding: TPUEmbedding, create dummy table variable to be used with + tpu_embedding. + + Returns: + An OrderedDict mapping feature name to gradient. + + Raises: + ValueError: if some gradients are not defined. + """ + g = ops.get_default_graph() + feature_to_gradient_dict = collections.OrderedDict() + for table_id, table in enumerate(tpu_embedding.table_to_config_dict): + table_gradients = g.get_collection( + 'tpu_embedding_gradients_table_{}'.format(table_id)) + if any(gradient is None for gradient in table_gradients): + raise ValueError( + 'Table {} with id {} has undefined gradients: this is probably ' + 'because the model asked TPUEmbedding to compute activations that ' + 'were not used.'.format(table, table_id)) + for feature, gradient in zip(tpu_embedding.table_to_features_dict[table], + table_gradients): + feature_to_gradient_dict[feature] = gradient + return feature_to_gradient_dict diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 4f761e3599bfbd3a9429c8d456ae0b368229904f..b2019c1083653f9af2c273cdf24ba5b0364bf478 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -32,7 +32,6 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.ops import tpu_ordinal_selector_op from tensorflow.contrib.tpu.python.tpu import _tpu_estimator_embedding from tensorflow.contrib.tpu.python.tpu import error_handling from tensorflow.contrib.tpu.python.tpu import functional as tpu_functional @@ -41,7 +40,9 @@ from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_context +from tensorflow.contrib.tpu.python.tpu import tpu_embedding_gradient from tensorflow.contrib.tpu.python.tpu import tpu_feed +from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.contrib.tpu.python.tpu._tpu_estimator_embedding import AdamParameters # pylint: disable=unused-import @@ -1364,13 +1365,13 @@ def call_computation(computation, # TPU core with every `Session.run()` call. Note that the entire inference # graph executes on a single core, and that invocations of this graph # will round-robin among the cores attached to a host. - @function.Defun() + @function.Defun(capture_resource_var_by_value=False) def tpu_subgraph(): return computation() return tpu_functional.TPUPartitionedCall( args=tpu_subgraph.captured_inputs, - device_ordinal=tpu_ordinal_selector_op.tpu_ordinal_selector(), + device_ordinal=tpu_ops.tpu_ordinal_selector(), Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg], f=tpu_subgraph) else: @@ -1396,11 +1397,19 @@ class _ModelFnWrapper(object): def call_without_tpu(self, features, labels, is_export_mode): return self._call_model_fn(features, labels, is_export_mode=is_export_mode) - def _add_embedding_features(self, features): + def _add_embedding_features(self, features, hook_dummy_table_variables): + """Add embedding features, optionally add hook to intercept gradient.""" if self._ctx.embedding_config: tpu_embedding_ = self._ctx.embedding_config.tpu_embedding embedding_activations = tpu_embedding_.get_activations() - features.update(embedding_activations) + if hook_dummy_table_variables: + new_embedding_activations = ( + tpu_embedding_gradient.hook_dummy_table_variables_to_activations( + tpu_embedding_, embedding_activations, + self._ctx.embedding_config.dummy_table_variables)) + features.update(new_embedding_activations) + else: + features.update(embedding_activations) def convert_to_single_tpu_train_step(self, dequeue_fn): """Converts user provided model_fn` as a single train step on TPU. @@ -1434,7 +1443,7 @@ class _ModelFnWrapper(object): del loss # unused; required in function signature. inputs = dequeue_fn() features, labels = inputs.features_and_labels() - self._add_embedding_features(features) + self._add_embedding_features(features, True) estimator_spec = self._verify_estimator_spec( self._call_model_fn(features, labels)) @@ -1447,19 +1456,17 @@ class _ModelFnWrapper(object): captured_training_hooks.capture(estimator_spec.training_hooks) - if tensor_tracer.TensorTracer.is_enabled(): - tt = tensor_tracer.TensorTracer() - loss = tt.trace_tpu(ops.get_default_graph(), - loss, train_op, - self._ctx.num_replicas, - self._ctx.num_of_replicas_per_host, - self._ctx.num_hosts) - if self._ctx.embedding_config is None: apply_sparse_grads = [] else: tpu_embedding_ = self._ctx.embedding_config.tpu_embedding - apply_sparse_grads = [tpu_embedding_.generate_send_gradients_op()] + gradients = ( + tpu_embedding_gradient.get_gradients_through_dummy_table_variables( + tpu_embedding_) + ) + apply_sparse_grads = [ + tpu_embedding_.generate_send_gradients_op(gradients) + ] # We must run train_op to update the variables prior to running the # outfeed. @@ -1509,7 +1516,7 @@ class _ModelFnWrapper(object): """Evaluation step function for use inside a while loop.""" inputs = dequeue_fn() features, labels = inputs.features_and_labels() - self._add_embedding_features(features) + self._add_embedding_features(features, False) tpu_estimator_spec = self._call_model_fn(features, labels) if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access @@ -2465,8 +2472,14 @@ class TPUEstimator(estimator_lib.Estimator): device_assignment = ctx.device_assignment else: device_assignment = None - tensors_on_cpu = tpu.rewrite_for_inference( - tpu_computation, device_assignment=device_assignment) + + if self._experimental_exported_model_uses_all_cores: + tensors_on_cpu = tpu.rewrite( + tpu_computation, device_assignment=device_assignment) + else: + tensors_on_cpu = tpu.rewrite_for_inference( + tpu_computation, device_assignment=device_assignment) + (estimator_spec, export_outputs_dict, export_outputs_list, predictions_dict) = ( tpu_capture.get()) @@ -2777,8 +2790,12 @@ class TPUEstimator(estimator_lib.Estimator): input_fn = features tpu_init_ops = [] - if ctx.embedding_config: - tpu_init_ops.extend(ctx.embedding_config.tpu_embedding.init_ops) + if ctx.embedding_config and mode == model_fn_lib.ModeKeys.TRAIN: + dummy_table_variables, dummy_table_variables_init = ( + tpu_embedding_gradient.create_dummy_table_variables( + ctx.embedding_config.tpu_embedding)) + ctx.embedding_config.dummy_table_variables = dummy_table_variables + tpu_init_ops.append(dummy_table_variables_init) input_holders = _InputPipeline(input_fn, batch_axis, ctx) enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = ( @@ -3140,6 +3157,7 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): captured_training_hooks) = ( model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) + @tpu_function.on_device_training_loop def multi_tpu_train_steps_on_single_shard(): return training_loop.repeat(iterations_per_loop_var, single_tpu_train_step, [_INITIAL_LOSS]) @@ -3162,6 +3180,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): captured_predict_hooks ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) + @tpu_function.on_device_training_loop def multi_tpu_predict_steps_on_single_shard(): def cond(scalar_stopping_signal): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index d5957b7e8ec40b40c7af8822378cee6134ef0d0f..97fddbc2adb688b3e5ec8c3f39adcebd8db6cbc7 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -37,6 +37,86 @@ from tensorflow.python.ops import array_ops from tensorflow.python.util import nest +def partition_or_replicate_on_host(tensor, dims): + """Partitions or replicates the input tensor. + + The ops inside this function are placed on the host side. + + Args: + tensor: The input tensor which will be partioned or replicated. + dims: A list of integer describes how to partition the input tensor. + + Returns: + An iterator of `Tensor`s or a list of partioned tensors. + """ + if dims is None: + return itertools.repeat(tensor) + dims = np.array(dims) + output = [tensor] + shape_list = np.array(tensor.shape.as_list()) + quotients, remainders = np.divmod(shape_list, dims) + for axis, (quotient, remainder, dim, original_size) in enumerate( + zip(quotients, remainders, dims, shape_list)): + if dim <= 1: + continue + if remainder > 0: + # For each dimension, when it cannot be evenly partitioned, XLA assumes + # tensors are partitioned in a greedy manner by using + # ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims + # are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] => + # [[(3, 4), (3, 4), (2, 4), (2, 2)], + # [(2, 4), (2, 4), (2, 4), (2, 2)]] + ceil_ratio = quotient + 1 + num_full_slots, left_over = np.divmod(original_size, ceil_ratio) + num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over] + if len(num_or_size_splits) < dim: + num_or_size_splits += [0] * (dim - len(num_or_size_splits)) + new_output = [] + for x in output: + new_output.append( + array_ops.split( + x, num_or_size_splits=num_or_size_splits, axis=axis)) + output = new_output + else: + output = [array_ops.split(x, dim, axis=axis) for x in output] + output = nest.flatten(output) + return output + + +def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims): + """Tags appropriate XLA sharding attribute to the dequeued tensor. + + Args: + tensor: The dequeued tensor on TPU. + dims: A list of integer describes how the tensor is partitioned. + + Returns: + The same tensor with the xla_sharding attribute. + """ + if dims is None: + return xla_sharding.replicate(tensor) + elif np.prod(dims) == 1: + return xla_sharding.assign_device(tensor, 0) + else: + tile_assignment = np.arange(np.prod(dims)).reshape(dims) + return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment) + + +def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims): + """Tags appropriate XLA sharding attribute to the dequeued tensors. + + Args: + dequeues: A list of dequeued tensors on TPU. + dims: A list of integer describes how the tensor is partitioned. + + Returns: + The same dequeues with appropriate xla_sharding attribute. + """ + nest.assert_shallow_structure(dequeues, dims) + return nest.map_structure_up_to( + dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims) + + class InfeedQueue(object): """A helper object to build a device infeed queue. @@ -706,7 +786,7 @@ class _PartitionedInfeedQueue(InfeedQueue): with ops.device(tpu.core(tpu_device)): values = tpu_ops.infeed_dequeue_tuple( dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) - return self._tag_sharding_attribute_for_dequeued_tensors( + return tag_sharding_attribute_for_dequeued_tensors( values, self._input_partition_dims) def generate_enqueue_ops(self, per_host_sharded_inputs): @@ -758,8 +838,9 @@ class _PartitionedInfeedQueue(InfeedQueue): inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs, self._input_partition_dims) inputs_parted_iters = [ - iter(self._partition_or_replicate_on_host(x, dims)) for x, dims in - zip(per_host_sharded_inputs[replica_index], inputs_part_dims_flat) + iter(self._check_dims_and_partition_or_replicate_on_host(x, dims)) + for x, dims in zip(per_host_sharded_inputs[replica_index], + inputs_part_dims_flat) ] for logical_core in xrange(self._device_assignment.num_cores_per_replica): @@ -789,14 +870,19 @@ class _PartitionedInfeedQueue(InfeedQueue): Args: tensor: Input tensor for partitioning. - dims: 1-D np.array of the list of integer describes how to partition the - input tensor. + dims: A list of integer describes how to partition the input tensor. Raises: ValueError: If the tensor can't be partitioned by dims or the num_cores_per_replica doesn't match the number of partitions(dims.prod()). """ + # No partitioning specified, so don't perform further checks. + if dims is None: + return + + dims = np.array(dims) + if (dims < 1).any(): raise ValueError("All input partition dims must be >= 1.") @@ -817,82 +903,17 @@ class _PartitionedInfeedQueue(InfeedQueue): tensor.shape.assert_is_fully_defined() - def _partition_or_replicate_on_host(self, tensor, dims): - """Partitions or replicates the input tensor. + def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims): + """Checks dims and partitions or replicates the input tensor. The ops inside this function are placed on the host side. Args: tensor: The input tensor which will be partioned or replicated. dims: A list of integer describes how to partition the input tensor. + Returns: An iterator of `Tensor`s or a list of partioned tensors. """ - if dims is None: - return itertools.repeat(tensor) - dims = np.array(dims) self._check_input_partition_dims(tensor, dims) - output = [tensor] - shape_list = np.array(tensor.shape.as_list()) - quotients, remainders = np.divmod(shape_list, dims) - for axis, (quotient, remainder, dim, original_size) in enumerate( - zip(quotients, remainders, dims, shape_list)): - if dim <= 1: - continue - if remainder > 0: - # For each dimension, when it cannot be evenly partitioned, XLA assumes - # tensors are partitioned in a greedy manner by using - # ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims - # are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] => - # [[(3, 4), (3, 4), (2, 4), (2, 2)], - # [(2, 4), (2, 4), (2, 4), (2, 2)]] - ceil_ratio = quotient + 1 - num_full_slots, left_over = np.divmod(original_size, ceil_ratio) - num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over] - if len(num_or_size_splits) < dim: - num_or_size_splits += [0] * (dim - len(num_or_size_splits)) - new_output = [] - for x in output: - new_output.append( - array_ops.split( - x, num_or_size_splits=num_or_size_splits, axis=axis)) - output = new_output - else: - output = [array_ops.split(x, dim, axis=axis) for x in output] - output = nest.flatten(output) - return output - - def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims): - """Tags appropriate XLA sharding attribute to the dequeued tensor. - - Args: - tensor: The dequeued tensor on TPU. - dims: A list of integer describes how the tensor is partitioned. - - Returns: - The same tensor with the xla_sharding attribute. - """ - if dims is None: - return xla_sharding.replicate(tensor) - elif np.prod(dims) == 1: - return xla_sharding.assign_device(tensor, 0) - else: - tile_assignment = np.arange(np.prod(dims)).reshape(dims) - return xla_sharding.tile( - tensor=tensor, - tile_assignment=tile_assignment) - - def _tag_sharding_attribute_for_dequeued_tensors(self, dequeues, dims): - """Tags appropriate XLA sharding attribute to the dequeued tensors. - - Args: - dequeues: A list of dequeued tensors on TPU. - dims: A list of integer describes how the tensor is partitioned. - - Returns: - The same dequeues with appropriate xla_sharding attribute. - """ - nest.assert_shallow_structure(dequeues, dims) - return nest.map_structure_up_to( - dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues, - dims) + return partition_or_replicate_on_host(tensor, dims) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_function.py b/tensorflow/contrib/tpu/python/tpu/tpu_function.py index 84d5967ea547f0c036f7c9aa936ac0c99c141304..422c7d3b26ffb4ad1b72450c4803ac2eb87cea3b 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_function.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_function.py @@ -55,3 +55,12 @@ def tpu_shard_context(number_of_shards): def get_tpu_context(): return _current_tpu_context + + +# Decorator function for tpu computation func that was passed to tpu.rewrite() +# if there is an embedded training loop in this func, trace tools will generate +# step markers for each iteration. +def on_device_training_loop(func): + # Value for this attribute is from xla.DebugOptions.StepMarkerLocation. + setattr(func, "step_marker_location", "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP") + return func diff --git a/tensorflow/contrib/tpu/python/tpu/training_loop.py b/tensorflow/contrib/tpu/python/tpu/training_loop.py index 0187b4bec6ecc55943bf48b9268a74e18ea5b488..50848e83f0ef8d999206909ebfe1b0bbc78d1e5b 100644 --- a/tensorflow/contrib/tpu/python/tpu/training_loop.py +++ b/tensorflow/contrib/tpu/python/tpu/training_loop.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.compiler import xla +from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import ops @@ -157,10 +158,18 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None): # TODO(phawkins): in principle this is too restrictive since it serializes # the training loop steps. In practice it does not matter since this loop # will be compiled by XLA. - return control_flow_ops.tuple(output_tensors, - control_inputs=output_operations) - else: - return output_tensors + output_tensors = control_flow_ops.tuple(output_tensors, + control_inputs=output_operations) + + if tensor_tracer.TensorTracer.is_enabled(): + num_replicas = tpu_function.get_tpu_context().number_of_shards + if num_replicas is None: + num_replicas = 1 + tt = tensor_tracer.TensorTracer() + output_tensors = tt.trace_tpu(ops.get_default_graph(), + output_tensors, None, + num_replicas) + return output_tensors # If the body has arity 0, add a dummy loop-carried value to which we can add # control dependencies from any side-effecting operations. diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index 07dbd5ca8d65ec8232d33c016a7369c68a4c9e1f..ada08f95ae46ea06b3896ca3b1603277d62bf6fc 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -22,7 +22,9 @@ cc_library( "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:sendrecv_ops_op_lib", "//tensorflow/core:tensorflow", "//tensorflow/core/kernels:immutable_constant_op", ], diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index cc242d0e3c9fbf26874d02f3d2ab81fa0dd36584..906e8695cd36722a69810f5e20eb31d92528b554 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -128,7 +128,6 @@ load( "tf_additional_libdevice_srcs", "tf_additional_minimal_lib_srcs", "tf_additional_mpi_lib_defines", - "tf_additional_proto_compiler_hdrs", "tf_additional_proto_hdrs", "tf_additional_proto_srcs", "tf_additional_test_deps", @@ -147,6 +146,7 @@ load( "tf_protos_grappler", "tf_protos_grappler_impl", "tf_pyclif_proto_library", + "tf_grpc_service_all", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", @@ -229,7 +229,7 @@ CORE_PROTO_SRCS = COMMON_PROTO_SRCS + ERROR_CODES_PROTO_SRCS # ones with individual proto_library targets. ADDITIONAL_CORE_PROTO_SRCS = [ "example/example_parser_configuration.proto", - "protobuf/checkpointable_object_graph.proto", + "protobuf/trackable_object_graph.proto", "protobuf/control_flow.proto", # TODO(ebrevdo): Re-enable once CriticalSection is in core. # "protobuf/critical_section.proto", @@ -418,9 +418,8 @@ cc_library( name = "platform_protobuf", srcs = tf_platform_hdrs([ "protobuf.h", - ]) + tf_platform_srcs([ - "protobuf.cc", ]) + [ + "platform/protobuf.cc", "platform/protobuf_util.cc", "lib/core/status.h", ], @@ -439,6 +438,17 @@ cc_library( ], ) +cc_library( + name = "grpc_services", + srcs = [], + hdrs = [ + "platform/grpc_services.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = tf_grpc_service_all(), +) + cc_library( name = "human_readable_json", srcs = tf_platform_srcs(["human_readable_json.cc"]), @@ -664,7 +674,7 @@ cc_library( name = "lib_proto_compiler", hdrs = [ "platform/protobuf_compiler.h", - ] + tf_additional_proto_compiler_hdrs(), + ], copts = tf_copts(), deps = tf_lib_proto_compiler_deps() + [ ":lib_proto_parsing", @@ -1049,13 +1059,13 @@ cc_library( "platform/default/integral_types.h", "platform/default/logging.h", "platform/default/mutex.h", - "platform/default/protobuf.h", "platform/default/thread_annotations.h", "platform/dynamic_annotations.h", "platform/macros.h", "platform/mutex.h", "platform/platform.h", "platform/prefetch.h", + "platform/protobuf.h", "platform/thread_annotations.h", "platform/types.h", "platform/cpu_info.h", @@ -1168,6 +1178,29 @@ tf_gen_op_libs( deps = [":lib"], ) +tf_gen_op_libs( + op_lib_names = [ + "tpu_configuration_ops", + "tpu_cross_replica_ops", + "tpu_embedding_ops", + "tpu_functional_ops", + "tpu_heartbeat_ops", + "tpu_host_compute_ops", + "tpu_infeed_ops", + "tpu_outfeed_ops", + "tpu_ordinal_selector_ops", + "tpu_replication_ops", + ], + deps = [ + ":lib", + ":lib_proto_parsing", + ":protos_all_cc", + "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", + "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", + "//tensorflow/core/tpu:tpu_embedding_output_layout_utils", + ], +) + # And one for all user ops cc_library( name = "user_ops_op_lib", @@ -1284,6 +1317,16 @@ cc_library( ":state_ops_op_lib", ":stateless_random_ops_op_lib", ":string_ops_op_lib", + ":tpu_configuration_ops_op_lib", + ":tpu_cross_replica_ops_op_lib", + ":tpu_embedding_ops_op_lib", + ":tpu_functional_ops_op_lib", + ":tpu_heartbeat_ops_op_lib", + ":tpu_host_compute_ops_op_lib", + ":tpu_infeed_ops_op_lib", + ":tpu_outfeed_ops_op_lib", + ":tpu_ordinal_selector_ops_op_lib", + ":tpu_replication_ops_op_lib", ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", @@ -1392,7 +1435,7 @@ cc_library( # This includes implementations of all kernels built into TensorFlow. cc_library( name = "all_kernels_impl", - visibility = ["//visibility:private"], + visibility = ["//tensorflow/core:__subpackages__"], deps = [ "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:audio", @@ -1551,6 +1594,7 @@ cc_library( ":framework_internal", ":lib", ":lib_internal", + ":ops", ":protos_all_cc", ":shape_inference_testutil", ":tensor_testutil", @@ -1897,6 +1941,7 @@ filegroup( "**/*testutil*", "**/*testlib*", "**/*main.cc", + "**/tpu_*", ], ), visibility = ["//visibility:public"], @@ -2282,6 +2327,7 @@ cc_library( "platform/**/logging.cc", "platform/**/human_readable_json.cc", "platform/abi.cc", + "platform/protobuf.cc", ], ) + tf_additional_lib_srcs( exclude = [ @@ -2958,6 +3004,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/lower_if_while.h", "common_runtime/lower_while_op.h", "common_runtime/memory_types.h", + "common_runtime/metrics.h", "common_runtime/mkl_cpu_allocator.h", "common_runtime/optimization_registry.h", "common_runtime/pending_counts.h", @@ -2969,6 +3016,8 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/rendezvous_mgr.h", "common_runtime/rendezvous_util.h", "common_runtime/ring_reducer.h", + "common_runtime/ring_alg.h", + "common_runtime/ring_gatherer.h", "common_runtime/session_factory.h", "common_runtime/single_threaded_cpu_device.h", "common_runtime/stats_publisher_interface.h", @@ -2993,6 +3042,8 @@ tf_cuda_library( "common_runtime/collective_param_resolver_local.cc", "common_runtime/collective_rma_local.cc", "common_runtime/collective_util.cc", + "common_runtime/colocation_graph.cc", + "common_runtime/colocation_graph.h", "common_runtime/constant_folding.cc", "common_runtime/copy_tensor.cc", "common_runtime/costmodel_manager.cc", @@ -3013,6 +3064,7 @@ tf_cuda_library( "common_runtime/lower_if_while.cc", "common_runtime/lower_while_op.cc", "common_runtime/memory_types.cc", + "common_runtime/metrics.cc", "common_runtime/mkl_cpu_allocator.cc", "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", @@ -3025,6 +3077,8 @@ tf_cuda_library( "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_util.cc", + "common_runtime/ring_alg.cc", + "common_runtime/ring_gatherer.cc", "common_runtime/ring_reducer.cc", "common_runtime/session.cc", "common_runtime/session_factory.cc", @@ -3083,7 +3137,6 @@ tf_cuda_library( ":framework", ":graph", ":lib", - ":metrics", ":proto_text", ":protos_all_cc", "//tensorflow/core/grappler:grappler_item", @@ -3114,15 +3167,6 @@ cc_library( deps = [":lib_internal"], ) -tf_cuda_library( - name = "metrics", - srcs = ["common_runtime/metrics.cc"], - hdrs = ["common_runtime/metrics.h"], - deps = [ - ":lib", - ], -) - tf_cuda_library( name = "direct_session_internal", srcs = ["common_runtime/direct_session.cc"], @@ -3139,7 +3183,6 @@ tf_cuda_library( ":graph", ":lib", ":lib_internal", - ":metrics", ":proto_text", ":protos_all_cc", "//tensorflow/core/debug:debug_graph_utils", @@ -3506,6 +3549,7 @@ tf_cc_tests( "platform/vmodule_benchmark_test.cc", ], deps = [ + ":core_cpu_internal", ":lib", ":lib_internal", ":lib_test_internal", @@ -3936,7 +3980,6 @@ tf_cc_test( "ops/cudnn_rnn_ops_test.cc", ], deps = [ - ":cudnn_rnn_ops", "//tensorflow/core", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -3996,6 +4039,35 @@ tf_cc_tests_gpu( ], ) +tf_cc_tests_gpu( + name = "ring_gatherer_test", + size = "medium", + srcs = [ + "common_runtime/ring_gatherer_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":all_kernels", + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":gpu_runtime", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":protos_test_cc", + ":test", + ":test_main", + ":testlib", + "@com_google_absl//absl/memory", + ], +) + tf_cc_tests_gpu( name = "hierarchical_tree_broadcaster_test", size = "medium", diff --git a/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt b/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d6f28bd022bcd843aa3a7aeb8b1b257a3b3ddfd3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt @@ -0,0 +1,67 @@ +op { + graph_op_name: "AllToAll" + in_arg { + name: "input" + description: <